diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2727d933191de0f2476e57df1850b14a48769f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..105a80ecd05ec5729dabbeed7428a35c06f69098 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32063628ccb9d0cda8dd202dcf3cb1ae250a103 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc0c7793f6f07e115866ea761c5429e489087657 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/aot_compile_types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6d4f42cfebb992a94d59e7b038ad9a01721c0ea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978df42551011c745d653649df01acf02fd194b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f170e39214012eb6958983260b1b751eb701891 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..132ee2a97aa36188b23e7ae88f6b8eb898f4f83d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c1d7563ca6774646b0f5f7350209ca15a34c77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/codegen.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/codegen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a16371ee2b13ebdf4dfcb2ea5f34ae693b902cf2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/codegen.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386498f8cb7e2c562e81c637065f2cbaac7f08f4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/comptime.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/comptime.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fe494e0360cca8e44761d800e03a87f7ec0ce86 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/comptime.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..723265689f69ad68590e091d940d8c509f5dbaeb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8219d99fedf42bed253f16a5f623984dc7043b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bceecef586cf02b72a64a2b34b1bef62468624b5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..544f63467c52869beae043fcc76fa801722d5c23 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/dce_extra_outputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/dce_extra_outputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52358893368451f37d159ecd1dce9d6791995e62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/dce_extra_outputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1786b3918e5dd2210539d088c602a714624a9c79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/decorators.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/decorators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2f81d46bd5c23366d1fff55637db082fc6de6fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/decorators.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb1de1fa5a02b6c621cb64c0f1452f980eda8cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a45f1be4363fff40004b3d994471f6622ba3811 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/exc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/exc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd9e1d90b8a3564c708fc531674d4f2013c87800 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/exc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcaa50d640c2c4a9f763cb029063bc8ee6a11c99 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6611a606b6e305b359a910db0d7b5a7dc50925d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f4e0cfffeaa5e99ff94fadee46ad2b37a95e6ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdabea6db345fa823cf535735d96dc9775ff1e1f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b1a3e40641a66d79d5cbf2c8858c8cc09768487 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03a9e51178e70e9f84051b23cb44e3f4e87b8d25 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b53acd527eb80d1a1fd3c04c0763c7948b6c394 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b826da9973187428365e39a45e16a5a7ba0567fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2da7cd8a4f9eeab5f7b23101b5cc4ad54c790b37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5be4fbf4b8e6a12d77369c2e63e76bc6d6fe6e54 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f47bc7b25ebec88d25e12f11802405f7783ac9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..487cc6e618dd0bcd207a4c3e711e16bc2e9c2d97 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5610c2adc36e9f85c8cd144d6a016a6086f2fdde Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/pgo.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/pgo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08e445408a1198dfb3d7bf8208ee22edbf917067 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/pgo.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0fbae5e7d763063c12f0b8ec629a515d07d794f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6cf7dd2ba2f9757ea3ce27782135f78c6edbf9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5be9a49fefc833323e14342625aa875421ad414 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d165d55e7d7aeeea8f4b0cbf3761a87bfafb635 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b7c82660e3997a2feb11fd0bcf4ffb6bd70dc40 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/source.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/source.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a23dd5b5b942755c977bd8d193b008d4ca610ae9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/source.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d502bb5e3dd5b9eb459de4636a2fd4647fe0102d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e9ba8fb421aa693dcb415a4744ce1806141c38 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b44ca9c309c9c301c13c24803db65f1ad6e626 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39e5e550d1331a99b12fac518a27a3eb3f844182 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..531e75bf76c0cbad0bab2ae365cf70c18a74ecfe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b3ccedd01b2605d228a31d47e7ef9f3d2f49d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/__pycache__/types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa572aedf8ad220651aa5c2eb6d47b8fe7018b4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6384206e5e15671f13fdb1aa5753fad1f43ad9f9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..082759922d2d4c878ce3af8c267719ce7e7172e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e657386da5308633e14bafe3a83c1118c70cbc23 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf6d7e2e06a8fb439769e05e0f9bcd9d0e18166b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed76cc166c6cb03f4558cfc35f219bb5cbb664f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b2b14a101be700f896cb3bed6cb11d4fecee1b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12ab6f937599604c5c2489ebfdea1b263932659c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db0152120deb13be3b03a9e5e4a34eeb3d1fb195 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c207cd623e2e6079d9720411435bebda5b24803c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7091ff5f2282272e860e24ecb91b27552210f5c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2b6ecff0c17d70fd978058c1b5a5915aa41158 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/common.py @@ -0,0 +1,183 @@ +""" +This module provides common utilities and base classes for TorchDynamo backends. + +Key components: +- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends +- Backend utilities for handling: + - Fake tensor conversion + - Device/dtype detection from inputs + - Memory efficient fusion + - Graph flattening + - Common compiler configurations + +The utilities here are used by various backend implementations to handle +common operations and provide consistent behavior across different backends. +AOT autograd functionality is particularly important as it enables ahead-of-time +optimization of both forward and backward passes. +""" + +import contextlib +import functools +import logging +from collections.abc import Callable, Iterable +from typing import Any +from typing_extensions import ParamSpec, TypeVar +from unittest.mock import patch + +import torch +from torch._dynamo import disable +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._dynamo.utils import counters, defake, flatten_graph_inputs +from torch._functorch.aot_autograd import ( + aot_module_simplified, + SerializableAOTDispatchCompiler, +) +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +class AotAutograd: + def __init__(self, **kwargs: Any) -> None: + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any + ) -> Callable[..., Any]: + if kwargs: + log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): + return flatten_graph_inputs( + gm, + example_inputs, + self, + ) + + # Hack to get around circular import problems with aot_eager_decomp_partition + if callable(self.kwargs.get("decompositions")): + self.kwargs["decompositions"] = self.kwargs["decompositions"]() + + # NB: dont delete counter increment + counters["aot_autograd"]["total"] += 1 + use_fallback = False + + if use_fallback: + log.debug("Unable to use AOT Autograd because graph has mutation") + counters["aot_autograd"]["not_ok"] += 1 + return gm + + def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]: + def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: + # Note [Wrapping bw_compiler in disable] + # The two disables here: + # - stop TorchDynamo from trying to compile the bw_compiler function itself + # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces + + return disable( + disable( + bw_compiler_fn, reason="do not trace backward compiler function" + )(*args, **kwargs), # type: ignore[misc] + reason="do not trace generated backwards pass", + ) + + _wrapped_bw_compiler._is_wrapped_bw_compiler = ( # pyrefly: ignore [missing-attribute] + True + ) + return _wrapped_bw_compiler + + bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] + + if isinstance(bw_compiler, SerializableAOTDispatchCompiler): + bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn) + elif getattr(bw_compiler, "_is_wrapped_bw_compiler", False): + bw_compiler.compiler_fn = bw_compiler + else: + bw_compiler = wrap_bw_compiler(bw_compiler) + + self.kwargs["bw_compiler"] = bw_compiler + self.kwargs["inference_compiler"] = ( + self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"] + ) + + from functorch.compile import nop + from torch._inductor.debug import enable_aot_logging + + # debug asserts slow down compile time noticeably, + # So only default them on when the aot_eager backend is used. + if self.kwargs.get("fw_compiler", None) is nop: + patch_config: contextlib.AbstractContextManager[Any] = patch( + "functorch.compile.config.debug_assert", True + ) + else: + patch_config = contextlib.nullcontext() + + try: + # NB: NOT cloned! + with enable_aot_logging(), patch_config: + cg = aot_module_simplified(gm, example_inputs, **self.kwargs) + counters["aot_autograd"]["ok"] += 1 + return disable(cg, reason="do not trace AOT-compiled graph") + except TensorifyScalarRestartAnalysis: + raise + except Exception: + counters["aot_autograd"]["not_ok"] += 1 + raise + + +def aot_autograd(**kwargs: Any) -> AotAutograd: + return AotAutograd(**kwargs) + + +def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]: + from functorch.compile import ( + default_decompositions, + min_cut_rematerialization_partition, + ts_compile, + ) + + kwargs = { + # these are taken from memory_efficient_fusion() + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + } + + if use_decomps: + kwargs["decompositions"] = default_decompositions + + return kwargs + + +def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any: + """ + Decorator for backends that need real inputs. We swap out fake + tensors for zero tensors. + """ + + @functools.wraps(fn) + def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any: + with _disable_current_modes(): + inputs = list(map(defake, inputs)) + return fn(model, inputs, **kwargs) # type: ignore[call-arg] + + return wrapper + + +def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device: + for x in example_inputs: + if hasattr(x, "device"): + return x.device + return torch.device("cpu") # Default fallback + + +def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype: + for x in example_inputs: + if hasattr(x, "dtype"): + return x.dtype + return torch.float32 # Default fallback diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..0346614583921620cf9c06433641c8b685d936aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py @@ -0,0 +1,299 @@ +""" +This module implements CUDA graphs support for TorchDynamo backends. + +CUDA graphs allow for capturing and replaying GPU operations, which can significantly +reduce CPU overhead in GPU-accelerated PyTorch models. This module provides: + +- CUDA graph creation and management for both forward and backward passes +- Input mutation detection and handling +- Device compatibility checking +- Stack trace management for debugging +- Integration with TorchInductor's cudagraph trees + +The backend supports two main modes: +1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization +2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking + +Key components: +- CudagraphsBackend: Main backend class for CUDA graph integration +- Mutation detection utilities to ensure graph safety +- Device mapping and compatibility checks +- Stack trace collection for debugging +""" + +import functools +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Any, Optional + +import torch +import torch.fx +from torch._dynamo import config +from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.backends.debugging import boxed_nop +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + check_multiple_devices_or_any_cpu_nodes, + format_default_skip_message, + get_mutation_stack_trace, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + get_first_incompatible_cudagraph_node, + num_fw_fixed_arguments, + output_node, +) +from torch.multiprocessing.reductions import StorageWeakRef + +from .registry import register_backend + + +def find_input_mutations(g: torch.fx.Graph) -> set[int]: + def meta_fk(meta: dict[str, Any]) -> Any: + return meta["val"] if "val" in meta else meta["fake_result"] + + inputs = defaultdict(set) + input_idx = 0 + mutated_inputs = set() + for n in g.nodes: + if n.op == "placeholder": + if isinstance(meta_fk(n.meta), torch.Tensor): + inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) + input_idx += 1 + elif n.op == "call_function": + if not hasattr(n.target, "_schema"): + continue + + schema = n.target._schema + for i, arg in enumerate(schema.arguments): + if i < len(n.args): + argument = n.args[i] + else: + if arg.name not in n.kwargs: + continue + argument = n.kwargs[arg.name] + mut_arg = False + if arg.alias_info: + if arg.alias_info.is_write: + mut_arg = True + if mut_arg: + # TODO: not correct for args that contain tensors in a struct + # like list + mutated_inputs |= inputs[ + StorageWeakRef(meta_fk(argument.meta)._typed_storage()) + ] + + # TODO: error on unrecognized nodes + return mutated_inputs + + +def get_device_node_mapping( + gm: torch.fx.GraphModule, +) -> dict[torch.device, torch.fx.Node]: + device_node_mapping: dict[torch.device, torch.fx.Node] = {} + for n in gm.graph.nodes: + t = n.meta.get("val", None) + if isinstance(t, torch.Tensor) and t.device not in device_node_mapping: + device_node_mapping[t.device] = n + return device_node_mapping + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + aot_model: torch.fx.GraphModule, num_fixed: int +) -> Optional[str]: + mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed)) + if not mutation_indices: + return None + + placeholders = get_placeholder_info(aot_model.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + +def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]: + if not config.cudagraph_backend_support_input_mutation: + if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor( + aot_model, num_fixed + ): + return mut_skip + + if skip := check_multiple_devices_or_any_cpu_nodes( + get_device_node_mapping(aot_model) + ): + return skip + + if node := get_first_incompatible_cudagraph_node(aot_model): + return format_default_skip_message(f"incompatible op ({node.name})") + + return None + + +def get_device_index(gm: torch.fx.GraphModule) -> int: + device = next(iter(get_device_node_mapping(gm))) + assert device.type == "cuda" + return device.index + + +def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]: + output = output_node(gm) + assert len(output.args) == 1 + args = output.args[0] + if not hasattr(args, "__iter__"): + return [] + return [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in args # type: ignore[union-attr] + ] + + +def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any: + from torch._inductor.cudagraph_trees import cudagraphify_impl + + do_cudagraphs = BoxedBool(True) + boxed_device_index = BoxedDeviceIndex(None) + + def forward_cudagraphs( + aot_model: torch.fx.GraphModule, + aot_inputs: list[Any], + is_inference: bool = False, + ) -> Any: + interp = boxed_nop(aot_model, aot_inputs) + fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) + if skip_msg := check_for_skip(aot_model, fixed): + BoxedBool.disable(do_cudagraphs) + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {skip_msg}" + ) + return interp + + boxed_device_index.set(get_device_index(aot_model)) + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=boxed_device_index.value, + is_backward=False, + is_inference=False, # Q: should forward is_inference here? + stack_traces=get_stack_traces(aot_model), + placeholders=get_placeholder_info(aot_model.graph), + mutated_input_idxs=find_input_mutations(aot_model.graph), + ) + out._boxed_call = True # type: ignore[attr-defined] + return out + + def backward_cudagraphs( + aot_model: torch.fx.GraphModule, aot_inputs: list[Any] + ) -> Any: + interp = boxed_nop(aot_model, aot_inputs) + if not do_cudagraphs: + return aot_model + + fixed = count_tangents(aot_model) + if skip_msg := check_for_skip(aot_model, fixed): + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {skip_msg}" + ) + + # See [Backward Generation Handling] + device_idx = boxed_device_index.value + if device_idx is None: + device_idx = 0 # Default to device 0 if not set + manager = torch._inductor.cudagraph_trees.get_manager( + device_idx, create_if_none_exists=False + ) + assert manager is not None + + def fn(inputs: list[Any]) -> Any: + # pyrefly: ignore [missing-attribute] + manager.set_to_running_backward() + return aot_model(inputs) + + fn._boxed_call = True # type: ignore[attr-defined] + return fn + + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=get_device_index(aot_model), + is_backward=True, + is_inference=False, + stack_traces=get_stack_traces(aot_model), + placeholders=get_placeholder_info(aot_model.graph), + mutated_input_idxs=find_input_mutations(aot_model.graph), + ) + out._boxed_call = True # type: ignore[attr-defined] + return out + + aot_cudagraphs = aot_autograd( + fw_compiler=forward_cudagraphs, + bw_compiler=backward_cudagraphs, + inference_compiler=functools.partial(forward_cudagraphs, is_inference=True), + keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation, + ) + return aot_cudagraphs(dynamo_model, dynamo_inputs) + + +class CudagraphsBackend: + compiler_name = "cudagraphs" + + @staticmethod + def reset() -> None: + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + @staticmethod + def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any: + return cudagraphs(model, inputs) + + +# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful +# for debugging and can serve as a perf baseline. +register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend()) + + +def cudagraphs_inner( + model: Callable[..., Any], + inputs: Sequence[Any], + copy_outputs: bool = True, + copy_inputs: bool = True, +) -> Callable[..., Sequence[Any]]: + """This isn't registered as a backend, but is used in some benchmarks""" + assert isinstance(inputs, (list, tuple)) + if copy_inputs: + static_inputs = [torch.zeros_like(x) for x in inputs] + else: + static_inputs = list(inputs) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + model(*inputs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + static_outputs = model(*static_inputs) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + def run(*new_inputs: Any) -> Sequence[Any]: + assert len(static_inputs) == len(new_inputs) + if copy_inputs: + for dst, src in zip(static_inputs, new_inputs): + dst.copy_(src) + graph.replay() + if copy_outputs: + return [x.clone() for x in static_outputs] + else: + return static_outputs + + return run diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..0e62e08cf1fc93a3acb11249e561ee06eb44e655 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py @@ -0,0 +1,558 @@ +""" +This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot +compilation and execution issues. It includes: + +Key Debugging Backends: +- eager: Simple pass-through backend that runs models in eager mode +- eager_noexcept: Similar to eager but with additional exception handling +- eager_debug: Adds schema validation checks for custom operators +- aot_eager: Uses AOT Autograd with nop compiler for debugging +- aot_eager_decomp_partition: Uses TorchInductor decompositions for debugging +- torchscript: Compiles using TorchScript for debugging JIT-related issues + +Testing and Development Tools: +- Backends for inducing specific errors (compile/runtime/accuracy) +- ExplainOutput class for detailed graph compilation analysis +- Utilities for cross-referencing and mode management +- Tools for graph detail inspection and break reason analysis + +These backends are primarily used for: +1. Debugging graph breaks and compilation failures +2. Testing error handling and recovery mechanisms +3. Analyzing performance bottlenecks +4. Validating operator schemas and decompositions +""" + +import dataclasses +import functools +import logging +from collections.abc import Callable, Iterable +from importlib import import_module +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from functorch.compile import min_cut_rematerialization_partition +from torch import _guards +from torch._dynamo.output_graph import GraphCompileReason +from torch._functorch import config as functorch_config +from torch._functorch.compilers import ts_compile + +from .common import aot_autograd +from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend + + +if TYPE_CHECKING: + from torch.fx.node import Target + + +log = logging.getLogger(__name__) + + +@register_backend +def eager( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager backend ignoring extra kwargs %s", kwargs) + return gm.forward + + +def make_eager_backend_with_torch_function_mode( + mode: torch.overrides.TorchFunctionMode, +) -> Callable[..., Any]: + return make_eager_backend_with_torch_function_modes([mode]) + + +def make_eager_backend_with_torch_function_modes( + modes: Iterable[torch.overrides.TorchFunctionMode], +) -> Callable[..., Any]: + """Used to trace HOPs (cond and while) for eager execution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + from contextlib import ExitStack + + def fn( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any + ) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + with ExitStack() as stack: + for mode in modes: + stack.enter_context(mode) + return gm.forward(*args, **kwargs) + + return wrapper + + return fn + + +@register_backend +def eager_noexcept( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs) + + # This backend is intended to check that dynamo-generated GraphModules + # do not cause errors. + def inner(*args: Any) -> Any: + try: + return gm(*args) + except Exception as e: + raise torch._dynamo.exc.TorchDynamoException( + "Unexpected exception when running generated GraphModule" + ) from e + + return inner + + +@register_backend +def pre_dispatch_eager( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> torch.fx.GraphModule: + if kwargs: + log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs) + + from torch.fx.experimental.proxy_tensor import make_fx + + def runnable_gm(*args: Any) -> Any: + return torch.fx.Interpreter(gm).run(*args) + + pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) + pre_dispatch_gm.print_readable() + + return pre_dispatch_gm + + +@register_backend +def eager_debug( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager_debug backend ignoring extra kwargs %s", kwargs) + + from torch._subclasses.schema_check_mode import SchemaCheckMode + + # We could add more debugging bits here. + # Right now, this backend can be used to check for and error on + # custom dispatcher ops that have incorrect schemas. + def inner(*args: Any) -> Any: + with SchemaCheckMode(): + return torch.fx.Interpreter(gm).run(*args) + + return inner + + +@register_backend(name="ts") # type: ignore[misc] +def torchscript( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] +) -> torch.jit.ScriptModule: + return torch.jit.script(gm) + + +# used boxed call to discard inputs when they are no longer needed +def boxed_nop( + fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Wrap the forward method in a function so we can set _boxed_call attribute + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def boxed_nop_with_mode( + fx_g: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + *, + mode: torch.overrides.TorchFunctionMode, +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Create a wrapper that runs with the mode + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + with mode: + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def fake_crossref_boxed_nop( + fx_g: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None, +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Create a wrapper that runs with the mode + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + with torch._subclasses.CrossRefFakeMode(ignore_op_fn): + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def ignore_builtins(op: torch._ops.OpOverload) -> bool: + return op.namespace in ("aten", "prims", "prim") + + +def get_nop_func() -> Callable[ + [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any] +]: + if not torch._functorch.config.fake_tensor_crossref: + return boxed_nop + elif torch._functorch.config.fake_tensor_crossref == "all": + return fake_crossref_boxed_nop + else: + assert torch._functorch.config.fake_tensor_crossref == "custom_ops" + return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins) + + +# Useful for debugging purpose +# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. +def aot_eager( + gm: torch.fx.GraphModule, + fake_tensor_inputs: list[torch.Tensor], + fw_compiler: Optional[Callable[..., Any]] = None, + bw_compiler: Optional[Callable[..., Any]] = None, + **kwargs: Any, +) -> Callable[..., Any]: + return aot_autograd( + fw_compiler=fw_compiler or boxed_nop, + bw_compiler=bw_compiler or boxed_nop, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + )(gm, fake_tensor_inputs, **kwargs) + + +register_backend(name="aot_eager", compiler_fn=aot_eager) + +aot_eager_default_partitioner = aot_autograd( + fw_compiler=boxed_nop, keep_inference_input_mutations=True +) +register_backend( + name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner +) + + +# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs +# inductor problems. +# aot_eager_decomp_partition just replaces the inductor compiler with nop to help +# isolate inductor vs aot_eager errors +def aot_eager_decomp_partition( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning( + "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs + ) + + from torch._inductor.compiler_bisector import CompilerBisector + + config_patches = {"unlift_effect_tokens": True} + if bisect_changes := CompilerBisector.get_config_change( + "aot_eager_decomp_partition" + ): + config_patches.update(bisect_changes) # type: ignore[arg-type] + + with functorch_config.patch(config_patches): + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=get_nop_func(), + bw_compiler=get_nop_func(), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition +) + + +# aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition, +# except that it takes a TorchDispatchMode mode and run the fw/bw in the mode +def aot_eager_decomp_partition_with_mode( + gm: torch.fx.GraphModule, + fake_tensor_inputs: list[torch.Tensor], + mode: Any, + **kwarg: Any, +) -> Callable[..., Any]: + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + bw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition_with_mode", + compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type] +) + + +def aot_eager_decomp_partition_crossref( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + # if the config is set, respect it, otherwise only test custom_ops. + # custom_op bad metas always manifest as an error whereas aten will only sometimes. + # by default, use the less noisy option + config_val = ( + "custom_ops" + if not functorch_config.fake_tensor_crossref + else functorch_config.fake_tensor_crossref + ) + with functorch_config.patch(fake_tensor_crossref=config_val): + return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs) + + +register_backend( + name="aot_eager_decomp_partition_crossref", + compiler_fn=aot_eager_decomp_partition_crossref, +) + + +# AOT Autograd with torchscript backend. Default partitioner. +# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser +# by using the relevant fuser with torch.jit.fuser(...) +aot_ts = aot_autograd(fw_compiler=ts_compile) +register_backend(name="aot_ts", compiler_fn=aot_ts) + +# These buggy backends are used for inducing bugs so that we can test +# our repro extraction / minifier scripts + + +class ReluCompileError(Exception): + pass + + +class TestingOnlyCompileError(Exception): + pass + + +@register_backend +def relu_compile_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + raise ReluCompileError + return gm + + +@register_backend +def relu_runtime_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + node.target = torch._assert + node.args = (False, "ReluRuntimeError") + gm.recompile() + return gm + + +@register_backend +def relu_accuracy_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm + + +@register_backend +def non_leaf_compile_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + # Require at least one non-trivial thing in the graph, + # see https://github.com/pytorch/pytorch/issues/102898 + for node in gm.graph.nodes: + if node.op == "call_function": + break + else: + return gm + for t in example_inputs: + if not t.is_leaf: + raise TestingOnlyCompileError + return gm + + +@dataclasses.dataclass +class ExplainOutput: + """ + This is the output of :func:`torch._dynamo.explain()` + There is no reason to create this class directly. + """ + + graphs: list[torch.fx.GraphModule] + graph_count: int + graph_break_count: int + break_reasons: list[GraphCompileReason] + op_count: int + ops_per_graph: Optional[list[list["Target"]]] = None + out_guards: Optional[list[_guards.Guard]] = None + compile_times: Optional[str] = None + + def __str__(self) -> str: + output = f"Graph Count: {self.graph_count}\n" + output += f"Graph Break Count: {self.graph_break_count}\n" + output += f"Op Count: {self.op_count}\n" + + output += "Break Reasons:\n" + for idx, break_reason in enumerate(self.break_reasons): + output += f" Break Reason {idx + 1}:\n" + output += f" Reason: {break_reason.reason}\n" + output += " User Stack:\n" + for frame_summary in break_reason.user_stack: + output += f" {frame_summary}\n" + + if self.ops_per_graph is not None: + output += "Ops per Graph:\n" + for idx, ops in enumerate(self.ops_per_graph): + output += f" Ops {idx + 1}:\n" + for op in ops: + output += f" {op}\n" + + if self.out_guards is not None: + output += "Out Guards:\n" + for i, guard in enumerate(self.out_guards): + output += f" Guard {i + 1}:\n" + output += f" {str(guard)}" + + if self.compile_times is not None: + output += f"Compile Times: {self.compile_times}\n" + return output + + +def _explain_graph_detail( + gm: torch.fx.GraphModule, + graphs: list[torch.fx.GraphModule], + op_count: int, + ops_per_graph: list[list["Target"]], + break_reasons: list[GraphCompileReason], +) -> tuple[ + torch.fx.GraphModule, + list[torch.fx.GraphModule], + int, + list[list["Target"]], + list[GraphCompileReason], +]: + """ + This function is a utility which processes a torch.fx.GraphModule and + accumulates information about its ops, graph breaks, and other details. It + is intended to be used by the ExplainWithBackend class and + `torch._dynamo.explain()` to provide details from Dynamo's graph capture. + + Parameters: + gm (torch.fx.GraphModule): The GraphModule to be processed. + graphs (list): A list that accumulates all the GraphModules processed. + op_count (int): The total count of operations in all GraphModules processed so far. + ops_per_graph (list): A list that accumulates the operations of each GraphModule. + break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule. + + Returns: + tuple: A tuple containing the processed GraphModule, the updated lists of graphs, + operations per graph, and break reasons, and the updated operation count. + """ + graphs.append(gm) + ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] + op_count += len(ops) + ops_per_graph.append(ops) + if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr] + break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type] + + return gm, graphs, op_count, ops_per_graph, break_reasons + + +class ExplainWithBackend: + """ + This class is intended to be used as a backend for `torch.compile`. It is + composable with other backends. When used in this way, it accumulates + information about graph breaks, ops, and other info and provides a string + representation summarizing this information. + + Attributes: + backend (str): The name of the backend to use for optimization. + graphs (list): A list of the graphs captured by TorchDynamo. + op_count (int): The total number of operations in all optimized graphs. + break_reasons (list): A list of graph break reasons with stack traces. + + Example Usage: + def fn(x): + x = torch.sigmoid(x) + return x + + torch._dynamo.reset() + eb = ExplainWithBackend("inductor") + optimized_fn = torch.compile(fn, backend=eb) + result = optimized_fn(torch.randn(5)) + print(eb.output()) + """ + + def __init__(self, backend: Union[CompilerFn, str]) -> None: + from .registry import lookup_backend + + self.backend = lookup_backend(backend) + self.graphs: list[torch.fx.GraphModule] = [] + self.op_count = 0 + self.break_reasons: list[GraphCompileReason] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> CompiledFn: + ops_per_graph: list[list[Target]] = [] + gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( + gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons + ) + return self.backend(gm, example_inputs) + + def output(self) -> ExplainOutput: + graph_count = len(self.graphs) + output = ExplainOutput( + self.graphs, + graph_count, + graph_count - 1, + self.break_reasons, + self.op_count, + ) + + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..e53becd884bbaf7f4d7c876cffb739c59a1717bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py @@ -0,0 +1,621 @@ +""" +This module implements distributed training optimizations for TorchDynamo backends. + +It provides functionality to optimize models wrapped in DistributedDataParallel (DDP) +by intelligently splitting compiled graphs to align with DDP's gradient synchronization +boundaries. Key features include: + +- Graph partitioning based on parameter bucket sizes +- Optimization of allreduce operations for distributed training +- Support for parameter ignoring and buffer handling +- Submodule compilation and management +- Debugging utilities for distributed training + +The main component is the DDPOptimizer class, which handles graph splitting and +recompilation to enable efficient distributed training while maintaining the benefits +of compilation. +""" + +import logging +import traceback +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Optional, TYPE_CHECKING +from unittest import mock + +import torch +from torch import fx +from torch._dynamo.backends.registry import CompiledFn, CompilerFn +from torch._dynamo.output_graph import GraphCompileReason +from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.node import Node + + +if TYPE_CHECKING: + from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta + + +# Regular log messages should go through 'log'. +# ddp_graph_log is a separate artifact logger reserved for dumping graphs. +# See docs/source/logging.rst for more info. +log = logging.getLogger(__name__) +ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs") + + +def args_str(args: Any) -> str: + # a debug helper + if torch.is_tensor(args): + return f"T[{args.shape}]" + elif isinstance(args, tuple): + return f"tuple({', '.join([args_str(x) for x in args])})" + elif isinstance(args, list): + return f"list({', '.join([args_str(x) for x in args])})" + else: + return str(args) + + +@dataclass +class Bucket: + size: int = 0 + params: list[str] = field(default_factory=list) + nodes: list[fx.Node] = field(default_factory=list) + + # param_ids is just used for unit testing + param_ids: list[int] = field(default_factory=list) + + # keep track of any buckets that were extended for logging purposes + opcount_increased_to_capture_external_output: int = 0 + paramsize_before_opcount_increase: int = 0 + + +def bucket_has_external_output(bucket: Bucket) -> bool: + nodes_in_bucket = set() + # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards + # so we don't reverse it here + for node in bucket.nodes: + # assume node.op != output, since those are filtered in the original iteration + nodes_in_bucket.add(node) + for user in node.users: + if user not in nodes_in_bucket: + return True + return False + + +def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None: + headers = ("Index", "Size (b)", "Param Names") + rows: list[tuple[Optional[int], Optional[int], str]] = [] + extended_buckets = [] + for idx, bucket in enumerate(reversed(buckets)): + if len(bucket.params) > 0: + rows.append((idx, bucket.size, bucket.params[0])) + rows.extend((None, None, param) for param in bucket.params[1:]) + if bucket.opcount_increased_to_capture_external_output > 0: + extended_buckets.append( + ( + idx, + bucket.opcount_increased_to_capture_external_output, + bucket.size - bucket.paramsize_before_opcount_increase, + ) + ) + + if rows: + log.info( + "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.", + bucket_bytes_cap, + len(buckets), + ) + + if extended_buckets: + log.warning( + "Some buckets were extended beyond their requested parameter capacities" + " in order to ensure each subgraph has an output node, required for fx graph partitioning." + " This can be the case when a subgraph would have only contained nodes performing inplace mutation," + " and returning no logical outputs. This should not be a problem, unless it results in too few graph" + " partitions for optimal DDP performance." + ) + + try: + from tabulate import tabulate + + log.debug( + "\nDDPOptimizer produced the following bucket assignments:\n%s", + tabulate(rows, headers=headers, tablefmt="simple_grid"), + ) + + if extended_buckets: + log.warning( + "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s", + tabulate( + extended_buckets, + headers=("Index", "Extra Ops", "Extra Param Size (b)"), + tablefmt="simple_grid", + ), + ) + except ImportError: + log.debug( + "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information." + ) + else: + log.debug("DDPOptimizer captured no parameters and did not split this graph.") + + +def has_higher_order_op(gm: fx.GraphModule) -> bool: + # Check if there is a higher order op in the graph + for node in gm.graph.nodes: + if node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if isinstance(maybe_param, torch.fx.GraphModule): + return True + return False + + +def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None: + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384 + module.meta = orig_gm.meta + module._param_name_to_source = orig_gm._param_name_to_source + + +def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None: + name_to_dynamo_source = {} + for node in orig_gm.graph.find_nodes(op="placeholder"): + name_to_dynamo_source[node.name] = node._dynamo_source + + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + for node in module.graph.find_nodes(op="placeholder"): + # non-placeholder in original_gm may become placeholder in submodules + node._dynamo_source = name_to_dynamo_source.get(node.name, None) + + +class DDPOptimizerContext: + def __init__(self) -> None: + self.curr_bucket: int = -1 + self.metadata_per_bucket: list[ViewAndMutationMeta] = [] + + +# compile each of the partitioned submodules using the user-provided compiler +class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__( + self, + module: fx.GraphModule, + compiler: CompilerFn, + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, + ) -> None: + super().__init__(module) + self.compiler = compiler + self.fake_mode = fake_mode + # See Note [DDPOptimizer and fw_metadata] + ctx = torch._guards.TracingContext.try_get() + if ctx is not None: + ctx.ddp_optimizer_ctx = DDPOptimizerContext() + + def compile_submod( + self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any + ) -> Any: + """ + Compile the submodule, + using a wrapper to make sure its output is always a tuple, + which is required by AotAutograd based compilers + """ + assert len(kwargs) == 0, "We assume only args for these modules" + + class WrapperModule(torch.nn.Module): + def __init__( + self, submod: Callable[..., Any], unwrap_singleton_tuple: bool + ) -> None: + super().__init__() + self.submod = submod + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, *args: Any) -> Any: + x = self.submod(*args) + # TODO(whc) + # for some reason the isinstance check is necessary if I split one node per submod + # - even though I supposedly wrapped the output in a tuple in those cases, the real + # compiled module was still returning a tensor + if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in input_mod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + + input_mod.recompile() + input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment] + "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." + " Set `torch._dynamo.config.optimize_ddp = False` to disable.", + [ + # it's close to useless to get a real stacktrace here, and quite verbose. + traceback.FrameSummary(__file__, 0, "DDPOptimizer"), + ], + ) + + wrapper = WrapperModule( + self.compiler(input_mod, args), + unwrap_singleton_tuple, + ) + return wrapper + + # Note: + # + # The way distributed works today around fake tensors can be somewhat confusing. + # Some of these codepaths are shared in both runtime, and compile time. The presence + # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. + # + # A few things to keep in mind: + # + # 1) We invoke `compile_submod` with a real module. The output of that gets stored + # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. + # + # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the + # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. + # + # 3) Fake tensors should always be around during compile time. + # + # 4) Fake tensors should never be around at runtime. + # + # 5) We end up with a compilation mode that takes a real submodule and fake tensors, + # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd] + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + new_args = [] + assert self.fake_mode + for arg in args: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor + ): + new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode)) + else: + new_args.append(arg) + + log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args)) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + if n.op == "call_module": + real_mod = self.fetch_attr(str(n.target)) + if self.fake_mode: + curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode) + else: + curr_submod = real_mod + + ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph) + + # When calling the compiler on the submod, inputs (new_args) are expected to + # be FakeTensors already since Dynamo would have made them FakeTensors in the + # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors, + # since this wrapping happens during compilation + + # Note: Returning Fake Tensors on First AOT Autograd Call + # + # Inductor will optimize strides of outputs when it deems it profitable. + # For instance, converting to channels last. When we split the graph here + # into multiple inductor compilations, we need to make sure that the + # output strides of one compilation is appropriately passed to the subsequent + # compilations. However, the mapping from inductor output to dynamo output + # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing, + # subclass handling, etc. In order to replay all this logic we set a flag such that + # the first invocation of inductor in aot_autograd will return Fake Tensors with + # appropriate strides. Then, all of aot autograd's runtime logic is replayed. + # This gives us the appropriately strided outputs here which will reflect runtime strides. + + class FakeifyFirstAOTInvocationGuard: + def __init__(self) -> None: + self.tc = torch._guards.TracingContext.try_get() + assert self.tc + self.tc.fakify_first_call = True + + def __del__(self) -> None: + self.tc.fakify_first_call = False # type: ignore[union-attr] + + # For aot_eager and other backends, tracing context is not set + has_tracing_context = torch._guards.TracingContext.try_get() is not None + if has_tracing_context: + g = FakeifyFirstAOTInvocationGuard() # noqa: F841 + + from torch._dynamo.utils import counters + + init = counters["aot_autograd"]["total"] + compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs) + + # TODO - better way of doing this? + # Only aot autograd handles fakifying first call + invoked_aot_autograd = init != counters["aot_autograd"]["total"] + + # We update the original (outer) graph with a call into the compiled module + # instead of the uncompiled one. + self.module.delete_submodule(n.target) # type: ignore[operator] + n.target = "compiled_" + n.target # type: ignore[operator] + self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator] + + # Finally, we have to produce inputs for use compiling the next submodule, + # and these need to be FakeTensors, so we execute the module under fake_mode + # Because parameters are not fake we patch fake tensor mode to allow non fake inputs + with ( + self.fake_mode, + mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True), + ): + if has_tracing_context and invoked_aot_autograd: + tracing_ctx = torch._guards.TracingContext.try_get() + assert tracing_ctx is not None + # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs + # Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances + ddp_ctx = tracing_ctx.ddp_optimizer_ctx + assert ddp_ctx is not None + assert tracing_ctx.fw_metadata is not None + ddp_ctx.curr_bucket += 1 + ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata) + + out = compiled_submod_real(*new_args, **kwargs) + # output should be fake or subclass + assert all( + (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor) + for t in (out if isinstance(out, (list, tuple)) else [out]) + ) + return out + else: + return curr_submod(*new_args, **kwargs) + else: + # placeholder or output nodes don't need to get compiled, just executed + return getattr(self, n.op)(n.target, new_args, kwargs) + + +class DDPOptimizer: + """Note [DDPOptimizer] + DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP), + breaking the dynamo graph into chunks to compile separately, with the breaks aligning to + the boundaries of gradient-allreduce buckets chosen by DDP. + + Background/Motivation + - DDP uses allreduce collectives to synchronize partial gradients computed on different workers + - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce + - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready + at around the same time during backward and thus can share the same allreduce efficiently + - Allreduces must overlap with backward compute for optimal training performance + - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which + operates when individual grads become 'ready' + - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the + autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole + fused backward function executes, preventing any overlap of compute and communication + + Algorithm + - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse + this graph in reverse order to determine the true order that gradients will become ready during backward. + - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started + and a graph break introduced + - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together + into an outer module that is returned to the user + + Notes + - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP, + and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does + in eager. + - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently + produce splits that do not necessarily align with the buckets used by DDP. This should result in performance + degradation approaching the baseline case where graph-splits are not used, but not worse. + - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the + subgraphs being compiled + - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers + left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are + also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP, + it is not catastrophic but could impact performance by choosing sub-optimal bucket splits. + - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients, + and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by + DDPOptimizer) + + Debugging + - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb. + - In many cases, the log messages are helpful (they show bucket size assignments)- + just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'. + - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model + in a single process (or with torchrun, in multiple processes) + + Args: + bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be + set to match the equivalent parameter on the original DDP module. + + backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph. + + first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP + special-cases the first bucket size since it is sometimes optimal to start a small allreduce early. + + """ + + def __init__( + self, + bucket_bytes_cap: int, + backend_compile_fn: CompilerFn, + first_bucket_cap: Optional[int] = None, + ) -> None: + if first_bucket_cap is not None: + self.first_bucket_cap = first_bucket_cap + elif torch.distributed.is_available(): + # this constant comes from C10D lib which is not always built + self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES + else: + self.first_bucket_cap = bucket_bytes_cap + + self.bucket_bytes_cap = bucket_bytes_cap + assert self.first_bucket_cap <= self.bucket_bytes_cap, ( + "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP" + ) + + self.backend_compile_fn = backend_compile_fn + + def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool: + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored + + def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None: + bucket.size += param.untyped_storage().nbytes() + bucket.params.append(name) + bucket.param_ids.append(id(param)) + + def add_module_params_to_bucket( + self, + mod: torch.nn.Module, + bucket: Bucket, + processed_modules: set[torch.nn.Module], + prefix: str, + ) -> None: + processed_modules.add(mod) + for name, param in mod.named_parameters(): + if param.requires_grad and not self._ignore_parameter(param): + self.add_param(bucket, param, f"{prefix}_{name}") + + def add_param_args(self, bucket: Bucket, node: fx.Node) -> None: + for arg in node.args: + if not isinstance(arg, torch.fx.node.Node): + continue + if arg.op != "placeholder": + continue + param = arg.meta["example_value"] + if ( + isinstance(param, torch.nn.Parameter) + and param.requires_grad + and not self._ignore_parameter(param) + ): + self.add_param(bucket, param, str(arg.target)) + + def compile_fn( + self, gm: fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> CompiledFn: + """ + Implements graph splitting, first determining a set of of buckets by counting + parameter sizes in reverse graph order, then invoking the user/backend compiler + to compile each subgraph. Finally, stiches compiled graphs into one graphmodule + and returns its callable. + """ + # 1: compute the partition map according to DDP bucket logic + buckets = [Bucket()] # (size, param_names) + processed_modules: set[torch.nn.Module] = set() + for node in reversed(gm.graph.nodes): + if node.op in ("output", "placeholder"): + continue + + if ( + buckets[0].size >= self.bucket_bytes_cap + or len(buckets) == 1 + and buckets[0].size >= self.first_bucket_cap + ): + if bucket_has_external_output(buckets[0]): + buckets.insert(0, Bucket()) + else: + # continue building this bucket past the point of filling its parameter capacity, + # to increase chances it contains at least one node that is either a global output or + # passed as input to a subsequent graph + + if buckets[0].opcount_increased_to_capture_external_output == 0: + buckets[0].paramsize_before_opcount_increase = buckets[0].size + buckets[0].opcount_increased_to_capture_external_output += 1 + + if node.op == "call_function": + self.add_param_args(buckets[0], node) + + elif node.op == "call_module": + target_mod = gm.get_submodule(node.target) + if target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], processed_modules, node.target + ) + elif node.op == "call_method": + if isinstance(node.args[0].target, str): + target_mod = None + try: + target_mod = gm.get_submodule(node.args[0].target) + except AttributeError: + pass + if target_mod is not None and target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], processed_modules, node.target + ) + # This handles situations like tmp = torch.mm(x, self.weight.t()) + # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None + # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None + self.add_param_args(buckets[0], node) + + elif node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if ( + isinstance(maybe_param, torch.nn.Parameter) + and maybe_param.requires_grad + and not self._ignore_parameter(maybe_param) + ): + self.add_param(buckets[0], maybe_param, node.target) + + # All nodes have to be mapped to a bucket, even if they don't have their own params + # Ignored params still end up in buckets, we just don't count them towards the capacity + buckets[0].nodes.append(node) + + if len(buckets) > 1 and buckets[0].size == 0: + # we collected a small preamble graph with ops that don't include parameters, fuse it back + buckets[1].nodes.extend(buckets[0].nodes) + assert len(buckets[0].params) == 0, "Params should be empty if size is 0" + del buckets[0] + + # stash buckets for testing/debugging purposes + self.buckets = buckets + pretty_print_buckets(buckets, self.bucket_bytes_cap) + + if len(buckets) == 1: + # bypass split/fuse logic if there is only one bucket + return self.backend_compile_fn(gm, example_inputs) + + # 2: partition the graphmodule according to bucket capacity + partition_map = {} + for idx, b in enumerate(buckets): + for node in b.nodes: + partition_map[node] = idx + + split_gm = fx.passes.split_module.split_module( + gm, + None, # type: ignore[arg-type] + lambda node: partition_map[node], + ) + + # See note [Assumption on Dynamo Metadata] + propagate_dynamo_source(gm, split_gm) + propagate_metadata(gm, split_gm) + + debug_str = ( + f"\n---orig graph---\n{gm.graph}\n" + + f"\n---split graph---\n{split_gm.graph}\n" + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # only print the submod graphs, not their children + debug_str += f"\n---{name} graph---\n{module.graph}\n" + debug_str += "\n---------------\n" + ddp_graph_log.debug(debug_str) + + trace_structured( + "optimize_ddp_split_graph", + payload_fn=lambda: split_gm.print_readable(print_output=False), + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + trace_structured( + "optimize_ddp_split_child", + lambda: {"name": name}, + payload_fn=lambda: module.print_readable(print_output=False), + ) + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + + submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode) + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + submod_compiler.run(*example_inputs) + split_gm.recompile() + + ddp_graph_log.debug( + "\n---final graph---\n%s\n---------------\n", split_gm.graph + ) + return split_gm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py new file mode 100644 index 0000000000000000000000000000000000000000..ae62dd56678b8349d27fe909f12482b884ca596c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py @@ -0,0 +1,31 @@ +""" +This module provides the TorchInductor backend integration for TorchDynamo. + +TorchInductor is a compiler backend that generates optimized code for both CPU and GPU. +This module lazily imports and registers the TorchInductor compiler to avoid loading it +into memory when it is not being used. This helps reduce memory overhead when using +other backends. + +The inductor backend can be used with torch.compile(): + model = torch.compile(model, backend="inductor") +""" + +from typing import Any + +from torch._dynamo import register_backend +from torch._dynamo.utils import dynamo_timed + + +@register_backend +def inductor(*args: Any, **kwargs: Any) -> Any: + with dynamo_timed("inductor_import", log_pt2_compile_event=True): + # do import here to avoid loading inductor into memory when it is not used + # The AsyncCompile subproc pool can be slow to start, so warm it up as early + # as possible. + from torch._inductor.async_compile import maybe_warm_pool + + maybe_warm_pool() + + from torch._inductor.compile_fx import compile_fx + + return compile_fx(*args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py new file mode 100644 index 0000000000000000000000000000000000000000..93490e64f4ae2044d0c641f8171e733ed7a8e141 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py @@ -0,0 +1,39 @@ +# This backend is maintained by ONNX team. To direct issues +# to the right people, please tag related GitHub issues with `module: onnx`. +# +# Maintainers' Github IDs: wschin, xadupre +# from torch.onnx._internal.onnxruntime import ( +# is_onnxrt_backend_supported, +# torch_compile_backend, +# ) + +# from .registry import register_backend + +""" +Placeholder for onnxruntime backend for dynamo +""" + +# def has_onnxruntime(): +# # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() +# return is_onnxrt_backend_supported() + + +# if is_onnxrt_backend_supported(): +# register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +# else: + +# def information_displaying_backend(*args, **kwargs): +# raise ImportError( +# "onnxrt is not registered as a backend. " +# "Please make sure all dependencies such as " +# "numpy, onnx, onnxscript, and onnxruntime-training are installed. " +# "Suggested procedure to fix dependency problem:\n" +# " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" +# " (2) Open a new python terminal.\n" +# " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" +# " (4) If it returns `True`, then you can use `onnxrt` backend.\n" +# " (5) If it returns `False`, please execute the package importing section in " +# "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." +# ) + +# register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1469ca478a38647f91b95f1eed8b2a0e6408dd66 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py @@ -0,0 +1,179 @@ +""" +This module implements TorchDynamo's backend registry system for managing compiler backends. + +The registry provides a centralized way to register, discover and manage different compiler +backends that can be used with torch.compile(). It handles: + +- Backend registration and discovery through decorators and entry points +- Lazy loading of backend implementations +- Lookup and validation of backend names +- Categorization of backends using tags (debug, experimental, etc.) + +Key components: +- CompilerFn: Type for backend compiler functions that transform FX graphs +- _BACKENDS: Registry mapping backend names to entry points +- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions + +Example usage: + @register_backend + def my_compiler(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn + + # Use registered backend + torch.compile(model, backend="my_compiler") + +The registry also supports discovering backends through setuptools entry points +in the "torch_dynamo_backends" group. Example: +``` +setup.py +--- +from setuptools import setup + +setup( + name='my_torch_backend', + version='0.1', + packages=['my_torch_backend'], + entry_points={ + 'torch_dynamo_backends': [ + # name = path to entry point of backend implementation + 'my_compiler = my_torch_backend.compiler:my_compiler_function', + ], + }, +) +``` +``` +my_torch_backend/compiler.py +--- +def my_compiler_function(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn +``` +Using `my_compiler` backend: +``` +import torch + +model = ... # Your PyTorch model +optimized_model = torch.compile(model, backend="my_compiler") +``` +""" + +import functools +import logging +from collections.abc import Callable, Sequence +from importlib.metadata import EntryPoint +from typing import Any, Optional, Protocol, Union + +import torch +from torch import fx + + +log = logging.getLogger(__name__) + + +class CompiledFn(Protocol): + def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + +CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn] + +_BACKENDS: dict[str, Optional[EntryPoint]] = {} +_COMPILER_FNS: dict[str, CompilerFn] = {} + + +def register_backend( + compiler_fn: Optional[CompilerFn] = None, + name: Optional[str] = None, + tags: Sequence[str] = (), +) -> Callable[..., Any]: + """ + Decorator to add a given compiler to the registry to allow calling + `torch.compile` with string shorthand. Note: for projects not + imported by default, it might be easier to pass a function directly + as a backend and not use a string. + + Args: + compiler_fn: Callable taking a FX graph and fake tensor inputs + name: Optional name, defaults to `compiler_fn.__name__` + tags: Optional set of string tags to categorize backend with + """ + if compiler_fn is None: + # @register_backend(name="") syntax + return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value] + assert callable(compiler_fn) + name = name or compiler_fn.__name__ + assert name not in _COMPILER_FNS, f"duplicate name: {name}" + if compiler_fn not in _BACKENDS: + _BACKENDS[name] = None + _COMPILER_FNS[name] = compiler_fn + compiler_fn._tags = tuple(tags) # type: ignore[attr-defined] + return compiler_fn + + +register_debug_backend = functools.partial(register_backend, tags=("debug",)) +register_experimental_backend = functools.partial( + register_backend, tags=("experimental",) +) + + +def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn: + """Expand backend strings to functions""" + if isinstance(compiler_fn, str): + if compiler_fn not in _BACKENDS: + _lazy_import() + if compiler_fn not in _BACKENDS: + from ..exc import InvalidBackend + + raise InvalidBackend(name=compiler_fn) + + if compiler_fn not in _COMPILER_FNS: + entry_point = _BACKENDS[compiler_fn] + if entry_point is not None: + register_backend(compiler_fn=entry_point.load(), name=compiler_fn) + compiler_fn = _COMPILER_FNS[compiler_fn] + return compiler_fn + + +# NOTE: can't type this due to public api mismatch; follow up with dev team +def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def] + """ + Return valid strings that can be passed to: + + torch.compile(..., backend="name") + """ + _lazy_import() + exclude_tags_set = set(exclude_tags or ()) + + backends = [ + name + for name in _BACKENDS + if name not in _COMPILER_FNS + or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined] + ] + return sorted(backends) + + +@functools.cache +def _lazy_import() -> None: + from .. import backends + from ..utils import import_submodule + + import_submodule(backends) + + from ..repro.after_dynamo import dynamo_minifier_backend + + assert dynamo_minifier_backend is not None + + _discover_entrypoint_backends() + + +@functools.cache +def _discover_entrypoint_backends() -> None: + # importing here so it will pick up the mocked version in test_backends.py + from importlib.metadata import entry_points + + group_name = "torch_dynamo_backends" + eps = entry_points(group=group_name) + eps_dict = {name: eps[name] for name in eps.names} + for backend_name in eps_dict: + _BACKENDS[backend_name] = eps_dict[backend_name] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..493e21a9dfc5fe929fdeefdf6153834d470ab561 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py @@ -0,0 +1,12 @@ +# import torch # type: ignore[import] +# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import] +# from .registry import register_backend # type: ignore[import] + +""" +Placeholder for TensorRT backend for dynamo via torch-tensorrt +""" + +# @register_backend +# def tensorrt(gm, example_inputs): +# import torch_tensorrt # type: ignore[import] +# pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py new file mode 100644 index 0000000000000000000000000000000000000000..60d7b87bd0876a85702c07db7c82cd804ee608d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py @@ -0,0 +1,55 @@ +import logging +from collections.abc import Callable +from typing import Any + +import torch +from functorch.compile import make_boxed_func +from torch import fx + +from ..backends.common import aot_autograd +from .registry import CompiledFn, register_backend, register_experimental_backend + + +log = logging.getLogger(__name__) + + +@register_experimental_backend +def openxla_eval( + model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] +) -> CompiledFn: + return xla_backend_helper(model, fake_tensor_inputs, boxed=False) + + +def openxla_eval_boxed( + model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] +) -> Callable[..., Any]: + return xla_backend_helper(model, fake_tensor_inputs, boxed=True) + + +def xla_backend_helper( + model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False +) -> Callable[..., Any]: + try: + import torch_xla.core.dynamo_bridge as bridge + except ImportError as e: + raise ImportError( + "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla" + ) from e + + compiled_graph = None + + def fwd(*args: torch.Tensor) -> Any: + nonlocal model + nonlocal compiled_graph + if compiled_graph is None: + compiled_graph = bridge.extract_compiled_graph(model, args) + del model + return compiled_graph(*args) + + return make_boxed_func(fwd) if boxed else fwd + + +openxla = aot_autograd( + fw_compiler=openxla_eval_boxed, +) +register_backend(name="openxla", compiler_fn=openxla) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py new file mode 100644 index 0000000000000000000000000000000000000000..02dde50de0fe02d793226b64d852967d99d31de6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py @@ -0,0 +1,197 @@ +""" +This module provides TVM backend integration for TorchDynamo. + +Apache TVM is a deep learning compiler framework that can optimize and execute +models on various hardware backends. This module enables: + +- Compilation of PyTorch models to TVM's computation graphs +- Multiple scheduling options: + - Default scheduler + - Auto-scheduler for automatic optimization + - Meta-schedule for evolutionary search-based tuning +- Hardware-specific optimizations: + - CUDA GPU support + - CPU support with LLVM targeting and architecture-specific tuning + - Automatic detection of CPU capabilities (AVX2, AVX512) +- Tensor conversion utilities between PyTorch and TVM formats +- Configurable optimization levels and tuning trials + +The backend can be used with torch.compile(): + model = torch.compile(model, backend="tvm") +""" + +import functools +import importlib +import logging +import os +import sys +import tempfile +from collections.abc import Callable +from pathlib import Path +from types import MappingProxyType +from typing import Any, Optional + +import torch +from torch import fx + +from .common import device_from_inputs, fake_tensor_unsupported +from .registry import register_backend + + +log = logging.getLogger(__name__) + + +@register_backend +@fake_tensor_unsupported # type: ignore[arg-type] +def tvm( + gm: fx.GraphModule, + example_inputs: list[torch.Tensor], + *, + options: Optional[MappingProxyType[str, Any]] = None, +) -> Callable[..., Any]: + if options is None: + options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3}) + assert options is not None + import tvm # type: ignore[import] + from tvm import relay # type: ignore[import] + from tvm.contrib import graph_executor # type: ignore[import] + + jit_mod = torch.jit.trace(gm, example_inputs) + device = device_from_inputs(example_inputs) + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + example_outputs = gm(*example_inputs) + if len(example_outputs) == 0: + log.warning("Explicitly fall back to eager due to zero output") + return gm.forward + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + scheduler = options.get("scheduler", None) + if scheduler is None: + scheduler = os.environ.get("TVM_SCHEDULER", None) + + trials = options.get("trials", 20000) + opt_level = options.get("opt_level", 3) + + if scheduler == "auto_scheduler": + # pyrefly: ignore [import-error] + from tvm import auto_scheduler + + with ( + tempfile.NamedTemporaryFile() as log_file, + auto_scheduler.ApplyHistoryBest(log_file), + tvm.transform.PassContext( + opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} + ), + ): + lib = relay.build(mod, target=target, params=params) + elif scheduler == "meta_schedule": + # pyrefly: ignore [import-error] + from tvm import meta_schedule as ms + + with tempfile.TemporaryDirectory() as work_dir: + if device.type != "cuda": + # meta_schedule needs num-cores to be specified + # here we use the maximum core count + target = tvm.target.Target( + f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}" + ) + # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch + # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + assert trials > 0 + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + work_dir=work_dir, + max_trials_global=trials, + num_trials_per_iter=64, + params=params, + strategy="evolutionary", + opt_level=opt_level, + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + opt_level=opt_level, + ) + elif scheduler == "default" or not scheduler: + # no autotuning + with tvm.transform.PassContext(opt_level=opt_level): + lib = relay.build(mod, target=target, params=params) + else: + raise NotImplementedError( + "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " + "There are three available options: default, auto_scheduler and meta_schedule." + ) + m = graph_executor.GraphModule(lib["default"](dev)) + + def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor: + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. + return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + + def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array: + """A helper function to transfer a torch.tensor to NDArray.""" + if torch_tensor.dtype == torch.bool: + # same reason as above, fallback to numpy conversion which + # could introduce data copy overhead + return tvm.nd.array(torch_tensor.cpu().numpy()) + return tvm.nd.from_dlpack(torch_tensor) + + def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]: + args = [a.contiguous() for a in i_args] + shape_info, _ = m.get_input_info() + active_inputs = {name for name, _ in shape_info.items()} + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + inp_name = f"inp_{idx}" + if inp_name not in active_inputs: + log.warning( + "input %s skipped as not found in tvm's runtime library", + inp_name, + ) + continue + m.set_input( + inp_name, + to_tvm_tensor(arg), + ) + m.run() + return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())] + + return exec_tvm + + +tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule") +tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler") + + +def has_tvm() -> bool: + try: + importlib.import_module("tvm") + return True + except ImportError: + return False + + +@functools.cache +def llvm_target() -> str: + if sys.platform == "linux": + cpuinfo = Path("/proc/cpuinfo").read_text() + if "avx512" in cpuinfo: + return "llvm -mcpu=skylake-avx512" + elif "avx2" in cpuinfo: + return "llvm -mcpu=core-avx2" + return "llvm" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59f6f76317e6daf4d6dbcfc93d363442b5e4335f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py @@ -0,0 +1,431 @@ +""" +Python polyfills for common builtins. +""" + +# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports. +# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py. +# Add it in the TYPE_CHECKING block below as well. + +# mypy: allow-untyped-defs + +import types +from collections import OrderedDict +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from itertools import repeat as _repeat +from operator import eq, ne +from typing import Any, TYPE_CHECKING + +import torch + +from ..utils import dict_keys + + +if TYPE_CHECKING: + # Load by torch._dynamo.polyfills.loader + # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py + # Put the submodules here to avoid circular imports + from . import ( + _collections as _collections, + builtins as builtins, + functools as functools, + itertools as itertools, + operator as operator, + os as os, + pytree as pytree, + struct as struct, + sys as sys, + ) + +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + + +def index(iterator, item, start=0, end=None): + from itertools import islice + + for i, elem in islice(enumerate(iterator), start, end): + if item == elem: + return i + # This will not run in dynamo + raise ValueError(f"{item} is not in {type(iterator)}") + + +def repeat(item, count): + for _ in range(count): + yield item + + +def radians(x): + import math + + return math.pi / 180.0 * x + + +def impl_CONTAINS_OP_fallback(a, b): + # performs fallback "a in b" + if hasattr(b, "__iter__"): + # use __iter__ if __contains__ is not available + for x in b: + if x == a: + return True + return False + raise TypeError(f"argument of type {type(b)} is not iterable") + + +def accumulate_grad(x, new_grad): + # polyfills according to the Gradient Layout Contract + if new_grad is None: + return + new_grad_strided = torch.empty_like(x) + new_grad_strided.copy_(new_grad) + if x.grad is None: + x.grad = new_grad_strided + elif torch.is_grad_enabled(): + x.grad = x.grad + new_grad_strided + else: + x.grad.add_(new_grad_strided) + + +# This mirrors +# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413 +def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]): + """emulate `(1,2,3) > (1,2)` etc""" + + # Optimization: For equality, short-circuit if lengths differ + # This avoids iterating through elements and triggering guards on SymInts + left_len = len(left) + right_len = len(right) + + if op is eq and left_len != right_len: + return False + if op is ne and left_len != right_len: + return True + + # Apply `op` to the first pair that differ + for a, b in zip(left, right): + if a != b: + return op(a, b) + + # No more pairs to compare, so compare sizes. + return op(left_len, right_len) + + +def dict___eq__(d, other): + if (len(d) != len(other)) or (d.keys() != other.keys()): + return False + + if all(isinstance(a, OrderedDict) for a in (d, other)): + return list(d.items()) == list(other.items()) + + for k, v in d.items(): + if v != other[k]: + return False + + return True + + +def set_symmetric_difference(set1, set2): + symmetric_difference_set = set() + for x in set1: + if x not in set2: + symmetric_difference_set.add(x) + for x in set2: + if x not in set1: + symmetric_difference_set.add(x) + return symmetric_difference_set + + +def set_symmetric_difference_update(set1, set2): + result = set1.symmetric_difference(set2) + set1.clear() + set1.update(result) + + +def set_isdisjoint(set1, set2): + if not isinstance(set2, Iterable): + raise TypeError(f"'{type(set2)}' object is not iterable") + + for x in set1: + for y in set2: + if not isinstance(y, Hashable): + raise TypeError(f"unhashable type: '{type(y)}'") + if x == y: + return False + return True + + +def set_intersection(set1, *others): + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + # return a new set with elements common in all sets + intersection_set = set() + for x in set1: + for set2 in others: + if not any(x == y for y in set2): + break + else: + intersection_set.add(x) + return intersection_set + + +def set_intersection_update(set1, *others): + result = set1.intersection(*others) + set1.clear() + set1.update(result) + + +def set_union(set1, *others): + # frozenset also uses this function + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.union expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + union_set = set(set1.copy()) + for set2 in others: + set_update(union_set, set2) + + # frozenset also uses this function + return type(set1)(union_set) + + +def set_update(set1, *others): + if len(others) == 0: + return set1 + + for set2 in others: + for x in set2: + if x not in set1: + set1.add(x) + + +def set_difference(set1, *others): + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + difference_set = set() + for x in set1: + for set2 in others: + if x in set2: + break + else: + difference_set.add(x) + return difference_set + + +def set_difference_update(set1, *others): + result = set1.difference(*others) + set1.clear() + set1.update(result) + + +def assert_dict_equal(self_, d1, d2, msg=None): + self_.assertTrue(d1 == d2, msg) + + +def assert_multi_line_equal(self_, first, second, msg=None): + return self_.assertTrue(first == second, msg) + + +# The original impl. uses difflib +def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None): + return self_.assertTrue(seq1 == seq2, msg) + + +def getattr_and_trace(*args, **kwargs): + wrapper_obj = args[0] + attr_name = args[1] + fn = getattr(wrapper_obj, attr_name) + return fn(*args[2:], **kwargs) + + +def mapping_get(obj, key, value=None, /): + try: + return obj.__getitem__(key) + except KeyError: + return value + + +def instantiate_user_defined_class_object(cls, /, *args, **kwargs): + obj = cls.__new__(cls, *args, **kwargs) + + # Only call __init__ if the object is an instance of the class + # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673 + if isinstance(obj, cls): + obj.__init__(*args, **kwargs) + return obj + + +def mutable_mapping_update(self, data=(), /, **kwargs): + if isinstance(data, Mapping): + # Merge standard mapping with PyMapping_Items + for key, value in data.items(): + self[key] = value + # FIXME: Enabling the `elif`-branch below needs too many `VariableClass.call_obj_hasattr` changes. + # >>> class Foo: + # ... def __init__(self): + # ... self.keys = lambda: ['a', 'b', 'c'] # not required to be a method + # ... + # ... def __getitem__(self, key): + # ... return 0 + # ... + # >>> dict(Foo()) + # {'a': 0, 'b': 0, 'c': 0} + # + # > This is a rare case, so we comment it out for now. + # + # elif hasattr(data, "keys"): + # # Merge mapping-like object with PyMapping_Keys + PyObject_GetItem + # for key in data.keys(): + # self[key] = data[key] + else: + if not isinstance(data, Iterable): + raise TypeError(f"{type(data).__name__!r} object is not iterable") + # Likely a sequence of pairs + for key, value in data: + self[key] = value + + if kwargs: + for key, value in kwargs.items(): + self[key] = value + + +# Used with something like dict(obj) +def construct_dict(cls, data=(), /, **kwargs): + self = cls.__new__(cls) + mutable_mapping_update(self, data, **kwargs) + return self + + +def foreach_map_fn(*args): + op = args[0] + new_args: list[Any] = [] + at_least_one_list = False + for arg in args[1:]: + if not isinstance(arg, (list, tuple)): + new_args.append(_repeat(arg)) + else: + at_least_one_list = True + new_args.append(arg) + + # Just apply op once to args if there are no lists + if not at_least_one_list: + return op(*args[1:]) + + out = [] + for unpacked in zip(*new_args): + out.append(op(*unpacked)) + + return out + + +def foreach_lerp_inplace(self, end, weight): + # decompose foreach lerp into constituent ops, prevents a graph break due to + # converting a value to a scalar when arg[2] is a single tensor + result = torch._foreach_sub(end, self) + result = torch._foreach_mul(result, weight) + return torch._foreach_add_(self, result) + + +def foreach_pow_scalar(scalar, exps): + return torch._foreach_pow([scalar for _ in exps], exps) + + +def addcmul_inplace(self, tensor1, tensor2, value): + return self.add_(tensor1 * tensor2 * value) + + +def predicate(obj: Any) -> bool: + # This will cause the rest of dynamo to handle the if statement correctly, so we don't have to rewrite it here. + # We can't just use bool() here since we can't trace into that in general. + if obj: + return True + return False + + +def cmp_eq(a, b): + # Note that the commented `is` check should ideally be removed. This is a + # CPython optimization that skips the __eq__ checks it the obj id's are + # same. But, these lines adds many `is` nodes in the Fx graph for + # SymNodeVariable. For now, we can just skip this check. This is STILL + # correct because one of the __eq__ checks will pass later, just could be + # slow in some corner cases. + # if a is b: + # return True + result = a.__eq__(b) + if result is NotImplemented: + result = b.__eq__(a) + return result is not NotImplemented and result + + +def cmp_ne(a, b): + # Check if __ne__ is overridden + if isinstance(type(a).__ne__, types.FunctionType): + return a.__ne__(b) + return not cmp_eq(a, b) + + +def cmp_lt(a, b): + result = a.__lt__(b) + if result is NotImplemented: + raise TypeError(f"{type(a)} does not support the < operator") + return result + + +def cmp_le(a, b): + # Check if __le__ is overridden + if isinstance(type(a).__le__, types.FunctionType): + return a.__le__(b) + return cmp_eq(a, b) or cmp_lt(a, b) + + +def cmp_gt(a, b): + # Check if __gt__ is overridden + if isinstance(type(a).__gt__, types.FunctionType): + return a.__gt__(b) + # a > b is equivalent to b < a + return cmp_lt(b, a) + + +def cmp_ge(a, b): + # Check if __ge__ is overridden + if isinstance(type(a).__ge__, types.FunctionType): + return a.__ge__(b) + return cmp_eq(a, b) or cmp_gt(a, b) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc8750688a39e4f528d7ef137d623a96c24a380f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b17edda4b5c3b3e07bfeedb1475128628ec7423 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73552d1061440d4adf0b510c92dfa6008830ddc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a14fa306d292d1de78047ad61f9030934ba2dda3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6133196b9d51cae2ae4d3f477ee707887ed3c4b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94772f374bdfd63d956caaea03e6e8e6b2822e5a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94992f12031ebdec7fece4a42f979f788d68fe5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b38716617937f73b7994430cf6f9211a82142e7d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e15e0c2d40dad7f61b1a568fa8fb772603e9dcfd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bb847cbc28b35e84f81be8df29a9006cb33bd86 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33f3ae87f8b5255a778cfd99c0673eb93c595959 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bc97b4fcbcad52d4c927bbaa4fb3f02b203612e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32fed8abf96d6abf24487e4d755b36b68c0e7509 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58bb58550a8f4138253227788ad5d7a88e2b7f11 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..9773635ae30587b06bb9f6b82c003392767b3873 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py @@ -0,0 +1,33 @@ +""" +Python polyfills for builtins +""" + +from collections.abc import Iterable, MutableMapping +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = [] + + +T = TypeVar("T") + + +try: + import _collections # type: ignore[import-not-found] + + @substitute_in_graph(_collections._count_elements) + def _count_elements( + mapping: MutableMapping[T, int], + iterable: Iterable[T], + ) -> None: + "Tally elements from the iterable." + mapping_get = mapping.get + for elem in iterable: + mapping[elem] = mapping_get(elem, 0) + 1 + + __all__.append("_count_elements") + +except ImportError: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..45feac9ca5dce561251c85794593c276dabaa4ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py @@ -0,0 +1,123 @@ +""" +Python polyfills for builtins +""" + +from __future__ import annotations + +import builtins +import functools +import operator +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +__all__ = [ + "all", + "any", + "enumerate", + "sum", +] + + +_T = TypeVar("_T") + + +@substitute_in_graph(builtins.all, can_constant_fold_through=True) +def all(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if not elem: + return False + return True + + +@substitute_in_graph(builtins.any, can_constant_fold_through=True) +def any(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if elem: + return True + return False + + +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] +def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] + return functools.reduce(operator.add, iterable, start) + + +class _CallableIterator: + def __init__(self, fn, sentinel): # type: ignore[no-untyped-def] + self.fn = fn + self.sentinel = sentinel + + def __iter__(self): # type: ignore[no-untyped-def] + return self + + def __next__(self): # type: ignore[no-untyped-def] + # The iterator created in this case will call object with no arguments + # for each call to its __next__() method; + r = self.fn() + + # If the value returned is equal to sentinel, StopIteration will be raised + if r == self.sentinel: + raise StopIteration + + # otherwise the value will be returned. + return r + + +class _SENTINEL_MISSING: + pass + + +# TODO(guilhermeleobas): use substitute_in_graph for iter() +def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /): # type: ignore[no-untyped-def] + # Without a second argument, object must be a collection object which supports + # the iterable (__iter__) or the sequence protocol (__getitem__ with an integer + # starting at 0) + if sentinel is _SENTINEL_MISSING: + iterable = fn_or_iterable + if hasattr(iterable, "__iter__"): + iterator = iterable.__iter__() + if hasattr(iterator, "__next__"): + return iterator + else: + raise TypeError(f"'{type(iterator)}' object is not iterable") + if hasattr(iterable, "__getitem__"): + # Needs to be a new function to avoid iter becoming a generator + def sequence_protocol(iterable): # type: ignore[no-untyped-def] + i = 0 + while True: + try: + yield iterable.__getitem__(i) + i += 1 + except IndexError: + break + + return sequence_protocol(iterable) + raise TypeError(f"'{type(iterable)}' object is not iterable") + else: + # If the second argument, sentinel, is given, then object must be a + # callable object. + fn = fn_or_iterable + + if not isinstance(fn, Callable): # type: ignore[arg-type] + raise TypeError("iter(v, w): v must be a callable") + + return _CallableIterator(fn, sentinel) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py new file mode 100644 index 0000000000000000000000000000000000000000..f70ca59bcea3eeab647583843bd1073e05e14639 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py @@ -0,0 +1,47 @@ +""" +Python polyfills for functools +""" + +import functools +from collections.abc import Callable, Iterable +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = ["reduce"] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class _INITIAL_MISSING: + pass + + +# Reference: https://docs.python.org/3/library/functools.html#functools.reduce +@substitute_in_graph(functools.reduce) +def reduce( + function: Callable[[_U, _T], _U], + iterable: Iterable[_T], + initial: _U = _INITIAL_MISSING, # type: ignore[assignment] + /, +) -> _U: + it = iter(iterable) + + value: _U + if initial is _INITIAL_MISSING: + try: + value = next(it) # type: ignore[assignment] + except StopIteration: + raise TypeError( + "reduce() of empty iterable with no initial value", + ) from None + else: + value = initial + + for element in it: + value = function(value, element) + + return value diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5ed97e0899d94fc4478de5acfa7879f5560ab2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py @@ -0,0 +1,41 @@ +from collections.abc import Callable +from typing import Any + +from torch._C import _fx_map_aggregate, _fx_map_arg +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.node import Node + +from ..decorators import substitute_in_graph + + +@substitute_in_graph(_fx_map_arg, can_constant_fold_through=True) +def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any: + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + + +@substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True) +def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: + result: Any + if isinstance(a, tuple): + it = (map_aggregate(elem, fn) for elem in a) + # Support NamedTuple (if it has `_fields`) by repacking into original type. + result = type(a)(*it) if hasattr(a, "_fields") else tuple(it) + elif isinstance(a, list): + result = immutable_list([map_aggregate(elem, fn) for elem in a]) + elif isinstance(a, dict): + result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()]) + elif isinstance(a, slice): + result = slice( + map_aggregate(a.start, fn), + map_aggregate(a.stop, fn), + map_aggregate(a.step, fn), + ) + else: + result = fn(a) + return result + + +__all__ = [ + "map_arg", + "map_aggregate", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py new file mode 100644 index 0000000000000000000000000000000000000000..feddb5723614f581fdd232a162feaf00a3ca2fae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py @@ -0,0 +1,119 @@ +""" +Python polyfills for heapq +""" + +from __future__ import annotations + +import heapq +import importlib +import sys +from typing import TYPE_CHECKING, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from types import ModuleType + + +_T = TypeVar("_T") + + +# Partially copied from CPython test/support/import_helper.py +# https://github.com/python/cpython/blob/bb8791c0b75b5970d109e5557bfcca8a578a02af/Lib/test/support/import_helper.py +def _save_and_remove_modules(names: set[str]) -> dict[str, ModuleType]: + orig_modules = {} + prefixes = tuple(name + "." for name in names) + for modname in list(sys.modules): + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules + + +def import_fresh_module(name: str, blocked: list[str]) -> ModuleType: + # Keep track of modules saved for later restoration as well + # as those which just need a blocking entry removed + names = {name, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None # type: ignore[assignment] + + try: + return importlib.import_module(name) + finally: + _save_and_remove_modules(names) + sys.modules.update(orig_modules) + + +# Import the pure Python heapq module, blocking the C extension +py_heapq = import_fresh_module("heapq", blocked=["_heapq"]) + + +__all__ = [ + "_heapify_max", + "_heappop_max", + "_heapreplace_max", + "heapify", + "heappop", + "heappush", + "heappushpop", + "heapreplace", + "merge", + "nlargest", + "nsmallest", +] + + +@substitute_in_graph(heapq._heapify_max) +def _heapify_max(heap: list[_T], /) -> None: + return py_heapq._heapify_max(heap) + + +@substitute_in_graph(heapq._heappop_max) # type: ignore[attr-defined] +def _heappop_max(heap: list[_T]) -> _T: + return py_heapq._heappop_max(heap) + + +@substitute_in_graph(heapq._heapreplace_max) # type: ignore[attr-defined] +def _heapreplace_max(heap: list[_T], item: _T) -> _T: + return py_heapq._heapreplace_max(heap, item) + + +@substitute_in_graph(heapq.heapify) +def heapify(heap: list[_T], /) -> None: + return py_heapq.heapify(heap) + + +@substitute_in_graph(heapq.heappop) +def heappop(heap: list[_T], /) -> _T: + return py_heapq.heappop(heap) + + +@substitute_in_graph(heapq.heappush) +def heappush(heap: list[_T], item: _T) -> None: + return py_heapq.heappush(heap, item) + + +@substitute_in_graph(heapq.heappushpop) +def heappushpop(heap: list[_T], item: _T) -> _T: + return py_heapq.heappushpop(heap, item) + + +@substitute_in_graph(heapq.heapreplace) +def heapreplace(heap: list[_T], item: _T) -> _T: + return py_heapq.heapreplace(heap, item) + + +@substitute_in_graph(heapq.merge) # type: ignore[arg-type] +def merge(*iterables, key=None, reverse=False): # type: ignore[no-untyped-def] + return py_heapq.merge(*iterables, key=key, reverse=reverse) + + +@substitute_in_graph(heapq.nlargest) # type: ignore[arg-type] +def nlargest(n, iterable, key=None): # type: ignore[no-untyped-def] + return py_heapq.nlargest(n, iterable, key=key) + + +@substitute_in_graph(heapq.nsmallest) # type: ignore[arg-type] +def nsmallest(n, iterable, key=None): # type: ignore[no-untyped-def] + return py_heapq.nsmallest(n, iterable, key=key) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbf9dfa1706751df86abcb55c2186c2ab47dd6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py @@ -0,0 +1,276 @@ +""" +Python polyfills for itertools +""" + +from __future__ import annotations + +import itertools +import operator +from collections.abc import Callable +from typing import Optional, overload, TYPE_CHECKING, TypeAlias, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + +__all__ = [ + "accumulate", + "chain", + "chain_from_iterable", + "compress", + "cycle", + "dropwhile", + "filterfalse", + "islice", + "tee", + "zip_longest", + "pairwise", +] + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_Predicate: TypeAlias = Callable[[_T], object] +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate +@substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type] +def accumulate( + iterable: Iterable[_T], + func: Optional[Callable[[_T, _T], _T]] = None, + *, + initial: Optional[_T] = None, +) -> Iterator[_T]: + # call iter outside of the generator to match cypthon behavior + iterator = iter(iterable) + if func is None: + func = operator.add + + def _accumulate(iterator: Iterator[_T]) -> Iterator[_T]: + total = initial + if total is None: + try: + total = next(iterator) + except StopIteration: + return + + yield total + for element in iterator: + total = func(total, element) + yield total + + return _accumulate(iterator) + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + # previous version of this code was: + # return itertools.chain(*iterable) + # If iterable is an infinite generator, this will lead to infinite recursion + for it in iterable: + yield from it + + +chain.from_iterable = chain_from_iterable # type: ignore[attr-defined] + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress +@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type] +def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]: + return (datum for datum, selector in zip(data, selectors) if selector) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle +@substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type] +def cycle(iterable: Iterable[_T]) -> Iterator[_T]: + iterator = iter(iterable) + + def _cycle(iterator: Iterator[_T]) -> Iterator[_T]: + saved = [] + for element in iterable: + yield element + saved.append(element) + + while saved: + for element in saved: + yield element + + return _cycle(iterator) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile +@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type] +def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: + # dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8 + + iterator = iter(iterable) + for x in iterator: + if not predicate(x): + yield x + break + + yield from iterator + + +@substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type] +def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: + it = iter(iterable) + if function is None: + return filter(operator.not_, it) + else: + return filter(lambda x: not function(x), it) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise +@substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] +def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee +@substitute_in_graph(itertools.tee) +def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: + iterator = iter(iterable) + shared_link = [None, None] + + def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] + try: + while True: + if link[1] is None: + link[0] = next(iterator) + link[1] = [None, None] + value, link = link + yield value + except StopIteration: + return + + return tuple(_tee(shared_link) for _ in range(n)) + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, +) -> Iterator[tuple[_T1 | None, _T2 | None]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], +) -> Iterator[tuple[_T | None, ...]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], + fillvalue: _U = ..., +) -> Iterator[tuple[_T | _U, ...]]: ... + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.zip_longest +@substitute_in_graph(itertools.zip_longest, is_embedded_type=True) # type: ignore[arg-type,misc] +def zip_longest( + *iterables: Iterable[_T], + fillvalue: _U = None, # type: ignore[assignment] +) -> Iterator[tuple[_T | _U, ...]]: + # zip_longest('ABCD', 'xy', fillvalue='-') -> Ax By C- D- + + iterators = list(map(iter, iterables)) + num_active = len(iterators) + if not num_active: + return + + while True: + values = [] + for i, iterator in enumerate(iterators): + try: + value = next(iterator) + except StopIteration: + num_active -= 1 + if not num_active: + return + iterators[i] = itertools.repeat(fillvalue) # type: ignore[arg-type] + value = fillvalue # type: ignore[assignment] + values.append(value) + yield tuple(values) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..31479e9d86ce6163c1c54ccdea73cc224ac82904 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py @@ -0,0 +1,45 @@ +# Used to load and initialize polyfill handlers when importing torch._dynamo +# Please add a new import when adding a new polyfill module. + +import importlib +from typing import TYPE_CHECKING + +import torch.utils._pytree as python_pytree + +from .. import polyfills, trace_rules + + +if TYPE_CHECKING: + from types import ModuleType + + +# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py +POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( + "_collections", + "builtins", + "functools", + "itertools", + "operator", + "os", + "struct", + "sys", + "fx", + "tensor", +) +if python_pytree._cxx_pytree_dynamo_traceable: + POLYFILLED_MODULE_NAMES += ("pytree",) + +POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( + importlib.import_module(f".{submodule}", package=polyfills.__name__) + for submodule in POLYFILLED_MODULE_NAMES +) + + +# Unregister the builtin functions from _builtin_function_ids to let them to be +# dispatched with the appropriate VariableTracker type. Otherwise, they will be +# dispatched with BuiltinVariable if present in _builtin_function_ids. +for polyfill_module in POLYFILLED_MODULES: + for polyfill_name in polyfill_module.__all__: + polyfill_handler = getattr(polyfill_module, polyfill_name) + original_fn = polyfill_handler.__torch_dynamo_original__ + trace_rules._builtin_function_ids.remove(id(original_fn)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..cae61df2c04307f294f1bf56fa68323acabc0e48 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py @@ -0,0 +1,119 @@ +""" +Python polyfills for operator +""" + +from __future__ import annotations + +import operator +from typing import Any, overload, TYPE_CHECKING, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + +# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) +__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"] + + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_Ts = TypeVarTuple("_Ts") +_U = TypeVar("_U") +_U1 = TypeVar("_U1") +_U2 = TypeVar("_U2") +_Us = TypeVarTuple("_Us") + + +@overload +# pyrefly: ignore [inconsistent-overload] +def attrgetter(attr: str, /) -> Callable[[Any], _U]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def attrgetter( + attr1: str, attr2: str, /, *attrs: str +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter +@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(attrs) == 0: + raise TypeError("attrgetter expected 1 argument, got 0") + + if any(not isinstance(attr, str) for attr in attrs): + raise TypeError("attribute name must be a string") + + def resolve_attr(obj: Any, attr: str) -> Any: + for name in attr.split("."): + obj = getattr(obj, name) + return obj + + if len(attrs) == 1: + attr = attrs[0] + + def getter(obj: Any) -> Any: + return resolve_attr(obj, attr) + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(resolve_attr(obj, attr) for attr in attrs) + + return getter + + +@overload +# pyrefly: ignore [inconsistent-overload] +def itemgetter(item: _T, /) -> Callable[[Any], _U]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def itemgetter( + item1: _T1, item2: _T2, /, *items: Unpack[_Ts] +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter +@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(items) == 0: + raise TypeError("itemgetter expected 1 argument, got 0") + + if len(items) == 1: + item = items[0] + + def getter(obj: Any) -> Any: + return obj[item] + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(obj[item] for item in items) + + return getter + + +# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller +@substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type] +def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]: + if not isinstance(name, str): + raise TypeError("method name must be a string") + + def caller(obj: Any) -> Any: + return getattr(obj, name)(*args, **kwargs) + + return caller + + +# Reference: https://docs.python.org/3/library/operator.html#operator.countOf +@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] +def countOf(a: Iterable[_T], b: _T, /) -> int: + return sum(it is b or it == b for it in a) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py new file mode 100644 index 0000000000000000000000000000000000000000..2f55d436ad8978bc0ddb46bdeeb356c518590547 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py @@ -0,0 +1,37 @@ +""" +Python polyfills for os +""" + +from __future__ import annotations + +import os +from typing import AnyStr + +from ..decorators import substitute_in_graph + + +__all__ = ["fspath"] + + +# Copied from os.py in the standard library +@substitute_in_graph(os.fspath, can_constant_fold_through=True) +def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: + if isinstance(path, (str, bytes)): + # pyrefly: ignore [bad-return] + return path + + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) # type: ignore[arg-type] + except AttributeError: + if hasattr(path_type, "__fspath__"): + raise + raise TypeError( + f"expected str, bytes or os.PathLike object, not {path_type.__name__}", + ) from None + if isinstance(path_repr, (str, bytes)): + return path_repr # type: ignore[return-value] + raise TypeError( + f"expected {path_type.__name__}.__fspath__() to return str or bytes, " + f"not {type(path_repr).__name__}", + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f9c1830333641b785b96780bb9b6b0475282e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py @@ -0,0 +1,758 @@ +""" +Python polyfills for torch.utils.pytree +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +from typing import Any, TYPE_CHECKING, TypeVar + +import optree +import optree._C +import optree.utils +from optree import ( + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + namedtuple_fields, + structseq_fields, +) + +import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 +import torch.utils._pytree as python_pytree +from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + import builtins + from collections.abc import Callable, Iterable, Mapping + from typing_extensions import Self, TypeIs + + from torch.utils._cxx_pytree import PyTree + + +__all__ = [ + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + "namedtuple_fields", + "structseq_fields", + "treespec_leaf", + "treespec_tuple", + "treespec_dict", + "tree_is_leaf", + "tree_iter", + "tree_leaves", + "tree_flatten", + "tree_flatten_with_path", + "tree_structure", + "tree_unflatten", +] + + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +@substitute_in_graph( + optree._C.is_dict_insertion_ordered, + can_constant_fold_through=True, +) +def _(*args: Any, **kwargs: Any) -> bool: + # In namespace 'torch', the dictionary is always traversed in insertion order. + # This function returns True. + raise ValueError( + "Should not be called directly " + "because the original function will be called in the constant fold path." + ) + + +__name = "" +for __name, __func in ( + ("is_namedtuple", is_namedtuple), + ("is_namedtuple_class", is_namedtuple_class), + ("is_namedtuple_instance", is_namedtuple_instance), + ("is_structseq", is_structseq), + ("is_structseq_class", is_structseq_class), + ("is_structseq_instance", is_structseq_instance), + ("namedtuple_fields", namedtuple_fields), + ("structseq_fields", structseq_fields), +): + globals()[__name] = substitute_in_graph( + __func, # type: ignore[arg-type] + can_constant_fold_through=True, + )(__func.__python_implementation__) # type: ignore[attr-defined] + del __func +del __name + + +@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_is_leaf( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> bool: + if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): + return True + if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: + return True + return False + + +@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] +def tree_iter( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> Iterable[Any]: + stack = [tree] + while stack: + node = stack.pop() + if tree_is_leaf( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ): + yield node + continue + + children, *_ = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + stack.extend(reversed(children)) + + +@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_leaves( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> list[Any]: + return list( + tree_iter( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + ) + + +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + + def __repr__(self) -> str: + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk + + +@dataclass(frozen=True) +class PyTreeSpec: + """Analog for :class:`optree.PyTreeSpec` in Python.""" + + _children: tuple[PyTreeSpec, ...] + _type: builtins.type | None + _metadata: Any + _entries: tuple[Any, ...] + _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None + none_is_leaf: bool + namespace: str + + num_nodes: int = field(init=False) + num_leaves: int = field(init=False) + num_children: int = field(init=False) + + def __post_init__(self, /) -> None: + if self._type is None: + assert len(self._children) == 0 + assert self._metadata is None + assert self._entries == () + assert self._unflatten_func is None + num_nodes = 1 + num_leaves = 1 + num_children = 0 + else: + assert callable(self._unflatten_func) + num_nodes = 1 + num_leaves = 0 + for child in self._children: + num_nodes += child.num_nodes + num_leaves += child.num_leaves + num_children = len(self._children) + + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + + def __repr__(self, /) -> str: + def helper(treespec: PyTreeSpec) -> str: + if treespec.is_leaf(): + assert treespec.type is None + return _asterisk + + assert treespec.type is not None + assert callable(treespec._unflatten_func) + children_representations = [ + helper(subspec) for subspec in treespec._children + ] + if ( + treespec.type in BUILTIN_TYPES + or (treespec.type is type(None) and not self.none_is_leaf) + or optree.is_namedtuple_class(treespec.type) + or optree.is_structseq_class(treespec.type) + ): + # pyrefly: ignore [bad-return] + return treespec._unflatten_func( + treespec._metadata, + children_representations, + ) + return ( + f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " + f"[{', '.join(children_representations)}])" + ) + + inner = [ + str(helper(self)), + *(["NoneIsLeaf"] if self.none_is_leaf else []), + f"namespace={self.namespace!r}", + ] + return f"PyTreeSpec({', '.join(inner)})" + + def __len__(self, /) -> int: + return self.num_leaves + + @property + def type(self, /) -> builtins.type | None: + return self._type + + def is_leaf(self, /) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def paths(self, /) -> list[tuple[Any, ...]]: + def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None: + if treespec.is_leaf(): + paths.append(path_prefix) + return + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper(subspec, path_prefix + [entry]) + + paths: list[list[Any]] = [] + helper(self, []) + return [tuple(path) for path in paths] + + def accessors(self, /) -> list[optree.PyTreeAccessor]: + def helper( + treespec: PyTreeSpec, + entry_path_prefix: list[optree.PyTreeEntry], + ) -> None: + if treespec.is_leaf(): + entry_paths.append(entry_path_prefix) + return + + node_type = treespec.type + assert node_type is not None + handler = optree.register_pytree_node.get( + node_type, namespace=treespec.namespace + ) + assert handler is not None + kind: optree.PyTreeKind = handler.kind + path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper( + subspec, + entry_path_prefix + [path_entry_type(entry, node_type, kind)], + ) + + entry_paths: list[list[optree.PyTreeEntry]] = [] + helper(self, []) + return [optree.PyTreeAccessor(path) for path in entry_paths] + + def children(self, /) -> list[PyTreeSpec]: + return list(self._children) + + def child(self, index: int, /) -> PyTreeSpec: + return self._children[index] + + def entries(self, /) -> list[Any]: + return list(self._entries) + + def entry(self, index: int, /) -> Any: + return self._entries[index] + + def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]: + def helper( + treespec: PyTreeSpec, + node: PyTree, + subtrees: list[PyTree], + ) -> None: + if treespec.is_leaf(): + subtrees.append(node) + return + + node_type = type(node) + if treespec.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != treespec.type: + raise ValueError( + f"Type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=self.none_is_leaf, + namespace=self.namespace, + ) + if len(children) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(children)}.", + ) + if metadata != treespec._metadata: + raise ValueError( + f"Node context mismatch for custom node type {treespec.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + treespec.type in STANDARD_DICT_TYPES + and node_type in STANDARD_DICT_TYPES + ) + if not both_standard_dict and node_type != treespec.type: + raise ValueError( + f"Node type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + if len(node) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(node)}.", + ) + + if both_standard_dict: + # dictionary types are compatible with each other + expected_keys = treespec.entries() + got_key_set = set(node) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + children = [node[key] for key in expected_keys] + else: + # node_type is treespec.type + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=self.none_is_leaf, + namespace=self.namespace, + ) + if ( + node_type is not deque # ignore mismatch of `maxlen` for deque + ) and metadata != treespec._metadata: + raise ValueError( + f"Node metadata mismatch for node type {treespec.type!r}; " + f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch + ) + + for subtree, subspec in zip(children, treespec._children, strict=True): + helper(subspec, subtree, subtrees) + + subtrees: list[PyTree] = [] + helper(self, tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any], /) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + # Recursively unflatten the children + start = 0 + end = 0 + subtrees = [] + for subspec in self._children: + end += subspec.num_leaves + subtrees.append(subspec.unflatten(leaves[start:end])) + start = end + + assert callable(self._unflatten_func) + return self._unflatten_func(self._metadata, subtrees) + + +def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]: + return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec)) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_leaf, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_leaf( + *, + none_is_leaf: bool = False, + namespace: str = "", # unused +) -> PyTreeSpec: + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace="", + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_tuple, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_tuple( + iterable: Iterable[PyTreeSpec] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + children = tuple(iterable) + if any(not _is_pytreespec_instance(child) for child in children): + raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") + if any(child.none_is_leaf != none_is_leaf for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {children!r}.", + ) + if any(child.namespace not in (namespace, "") for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {children!r}.", + ) + handler = optree.register_pytree_node.get(tuple, namespace=namespace) + assert handler is not None + return PyTreeSpec( + tuple(children), + tuple, + None, + tuple(range(len(children))), + handler.unflatten_func, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_dict, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_dict( + mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", + **kwargs: PyTreeSpec, +) -> PyTreeSpec: + dct = dict(mapping, **kwargs) + if any(not _is_pytreespec_instance(child) for child in dct.values()): + raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") + if any(child.none_is_leaf != none_is_leaf for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {dct!r}.", + ) + if any(child.namespace not in (namespace, "") for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {dct!r}.", + ) + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] + dct, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return PyTreeSpec( + tuple(children), # type: ignore[arg-type] + dict, + metadata, + entries, + unflatten_func, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[Any], PyTreeSpec]: + def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: + if tree_is_leaf( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ): + leaves.append(node) + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + # Recursively flatten the children + subspecs = tuple(helper(child, leaves) for child in children) + return PyTreeSpec( + subspecs, + type(node), + metadata, + entries, + unflatten_func, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) # type: ignore[arg-type] + + leaves: list[Any] = [] + treespec = helper(tree, leaves) + return leaves, treespec + + +@substitute_in_graph( # type: ignore[arg-type] + optree._C.flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def _C_flatten( + tree: PyTree, + /, + leaf_predicate: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[Any], PyTreeSpec]: + return tree_flatten( # type: ignore[return-value] + tree, + is_leaf=leaf_predicate, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten_with_path, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten_with_path( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + leaves, treespec = tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return treespec.paths(), leaves, treespec # type: ignore[return-value] + + +@substitute_in_graph( # type: ignore[arg-type] + optree._C.flatten_with_path, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def _C_flatten_with_path( + tree: PyTree, + /, + leaf_predicate: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + return tree_flatten_with_path( # type: ignore[return-value] + tree, + is_leaf=leaf_predicate, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_structure, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_structure( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + return tree_flatten( # type: ignore[return-value] + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + )[1] + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_unflatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + return treespec.unflatten(leaves) + + +_none_registration = optree.register_pytree_node.get(type(None)) +assert _none_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _none_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def none_unflatten(_: None, children: Iterable[_T], /) -> None: + if len(list(children)) != 0: + raise ValueError("Expected no children.") + return None + + +with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_flatten( + dct: dict[_KT, _VT], / +) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, +) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values, strict=True)) + return d # type: ignore[return-value] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py new file mode 100644 index 0000000000000000000000000000000000000000..f4522a12f7323e51da6f4454814e87daf82cea98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py @@ -0,0 +1,27 @@ +""" +Python polyfills for struct +""" + +from __future__ import annotations + +import struct +from typing import Any +from typing_extensions import Buffer + +from ..decorators import substitute_in_graph + + +__all__ = [ + "pack", + "unpack", +] + + +@substitute_in_graph(struct.pack, can_constant_fold_through=True) # type: ignore[arg-type] +def pack(fmt: bytes | str, /, *v: Any) -> bytes: + return struct.pack(fmt, *v) + + +@substitute_in_graph(struct.unpack, can_constant_fold_through=True) # type: ignore[arg-type] +def unpack(format: bytes | str, buffer: Buffer, /) -> tuple[Any, ...]: + return struct.unpack(format, buffer) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py new file mode 100644 index 0000000000000000000000000000000000000000..ab666c385806f9cd56e489038a0884be861c0bf3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py @@ -0,0 +1,34 @@ +""" +Python polyfills for sys +""" + +from __future__ import annotations + +import sys + +from ..decorators import substitute_in_graph + + +__all__ = [ + "intern", + "getrecursionlimit", +] + + +@substitute_in_graph(sys.intern, can_constant_fold_through=True) +def intern(string: str, /) -> str: + return string + + +@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True) +def getrecursionlimit() -> int: + return sys.getrecursionlimit() + + +if hasattr(sys, "get_int_max_str_digits"): + + @substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True) + def get_int_max_str_digits() -> int: + return sys.get_int_max_str_digits() + + __all__ += ["get_int_max_str_digits"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..dffa98f60f3b578810a2386255964d03858afa37 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,40 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + with torch._C.DisableTorchFunctionSubclass(): + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. + # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, ( + "_make_subclass only supports requires_grad as keyword arg" + ) + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. + return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5394f6df3545b6368a94f19c602b26ed2c1957e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44b52759a25e2c3c2998e5848ed6750c825b5157 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a8fe6c930a54361b919054f37857797af46ffdf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3699b7016c75365ed62a682e78fdc7d7e5cc6fb2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py new file mode 100644 index 0000000000000000000000000000000000000000..25ef68a111080a42e97e7fe738203e5a42e1f9df --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py @@ -0,0 +1,1281 @@ +""" +Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. + +This module provides tools and infrastructure for: +1. Generating minimal reproducible test cases ("repros") from failing compilations +2. Analyzing accuracy issues between eager and compiled execution +3. Minifying large models/inputs to isolate problematic patterns +4. Debugging compiler errors and accuracy divergences + +The main components include: +- Repro generation: Creates standalone Python files that reproduce compiler issues +- Minification: Reduces large graphs to minimal failing examples +- Accuracy analysis: Compares compiled vs eager execution, with fp64 reference +- Debug tools: Dumps graph state, tracks intermediates, analyzes divergences + +This is primarily used by PyTorch developers and researchers to debug issues in +the Dynamo AOT compilation pipeline, particularly for the Inductor backend. +""" + +from __future__ import annotations + +import argparse +import copy +import functools +import io +import logging +import os +import shutil +import subprocess +import sys +import textwrap +import uuid +from importlib import import_module +from tempfile import TemporaryFile +from typing import Any, IO, Optional, TYPE_CHECKING, Union +from typing_extensions import Unpack + +import sympy + + +try: + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction +except ImportError: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + class Heuristics: # type: ignore[no-redef] + pass + + +import torch +import torch.fx as fx +import torch.nn as nn +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + AccuracyError, + backend_accuracy_fails, + BuckTargetWriter, + cast_to_fp64, + extra_deps, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + MAX_CONSTANT_NUMEL_INLINE, + minifier_dir, + NNModuleToString, + NopInputReader, + same_two_models, +) +from torch._dynamo.utils import clone_inputs, counters, same +from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table +from torch._inductor.cpp_builder import normalize_path_separator +from torch._library.fake_class_registry import FakeScriptObject +from torch._ops import OpOverload +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + fx_placeholder_targets, + has_free_symbols, +) +from torch.hub import tqdm + +from .. import config + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def wrap_compiler_debug( + unconfigured_compiler_fn: _CompileFxCallable, + compiler_name: str, +) -> _CompileFxCallable: + """ + Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both + forward and backward call separately with the backend compiler_fn - like + inductor or nvfuser. Intercepting after Aot Autograd presents neat + abstraction, where all the params are lifted as graph inputs, making it easy + to save the graph as a string. + """ + + @functools.wraps(unconfigured_compiler_fn) + def debug_wrapper( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], + ) -> OutputCode: + from torch._subclasses import FakeTensorMode + + compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + + from torch._functorch.aot_autograd import get_aot_graph_name + + graph_name = get_aot_graph_name() + + # TODO: why do we need to deepcopy the original graph? + orig_graph = copy.deepcopy(gm.graph) + assert config.repro_after in ("dynamo", "aot", None) + + try: + # Call the compiler_fn - which is either aot_autograd or inductor + # with fake inputs + inner_compiled_fn = compiler_fn(gm, example_inputs) + except Exception: + # TODO: Failures here are troublesome because no real inputs, + # need a different serialization strategy + if config.repro_after == "aot": + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + log.error("CompilerError") + raise + + # We may run regular PyTorch compute that may trigger Dynamo, do NOT + # recursively attempt to accuracy minify in that case! + def deferred_for_real_inputs( + real_inputs: Sequence[InputType], **_kwargs: object + ) -> Any: + # This is a bit obscure: if we recursively try to accuracy minify + # the SAME function, this would trigger. But most of the time + # we should never hit this branch + assert not _kwargs + if config.repro_after != "aot": + assert not isinstance(inner_compiled_fn, str) + return inner_compiled_fn(real_inputs) + with config.patch(repro_after=None): + return inner_debug_fn(real_inputs) + + def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any: + """ + Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, + example_inputs can be fake tensors. We can call compiler_fn (which is + inductor or nvfuser) with fake tensors but the actually compiled_fn + should be called with real tensors. Therefore, the actual invocation + is deferred. + """ + # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor + # because inductor clears the tensor list in its codegen. And example_inputs + # are available only for the first invocation. + fake_mode = FakeTensorMode() + copy_tensor_attrs = [ + fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x + for x in real_inputs + ] + if config.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + fx.GraphModule(gm, orig_graph), real_inputs, compiler_name + ) + + if config.repro_level == 4: + if compiler_name != "inductor": + raise NotImplementedError( + "Accuracy minification is supported for inductor only" + ) + failed = not same_two_models( + gm, + inner_compiled_fn, # type: ignore[arg-type] + real_inputs, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + if failed: + log.warning( + "Accuracy failed for the AOT Autograd graph %s", graph_name + ) + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + dump_to_minify( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + raise AccuracyError("Bad accuracy detected") + else: + # Call the compiled function with real inputs + return inner_compiled_fn(real_inputs) # type: ignore[operator] + else: + try: + # Call the compiled function with real inputs + out = inner_compiled_fn(real_inputs) # type: ignore[operator] + # sync cuda kernels to ensure IMA detection + for arg in example_inputs: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + torch.cuda.synchronize() + break + return out + except Exception: + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + raise + + if config.repro_after == "aot": + compiled_fn = deferred_for_real_inputs + compiled_fn._boxed_call = True # type: ignore[attr-defined] + return compiled_fn # type: ignore[return-value] + else: + return inner_compiled_fn + + return debug_wrapper + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def maybe_fbcode_instructions() -> str: + if is_fbcode(): + extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) + if len(extra_deps_formatted) > 0: + extra_deps_formatted = "\n" + extra_deps_formatted + return f"""\ +\"\"\" +To run this script in fbcode: +- Create a directory (//scripts/{{your_unixname}}/repro) +- Put this file in scripts/{{your_unixname}}/repro/fx_graph_runnable.py +- Add a TARGETS file that looks like the following +- `buck2 run //scripts/{{your_unixname}}/repro:repro` + +NOTE: you may need additional deps to actually be able to run the script. +``` +# Contents of TARGETS file +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name = "repro", + main_src = "fx_graph_runnable.py", + deps = [ + "//caffe2:torch",{extra_deps_formatted} + ], +) +``` +\"\"\" +""" + else: + return "" + + +def generate_compiler_repro_string( + gm: torch.fx.GraphModule, + args: Sequence[Any], + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + stable_hash: bool = False, + has_distributed_ops: bool = False, +) -> str: + if save_dir is not None: + save_dir = normalize_path_separator(save_dir) + # Add distributed imports if needed + distributed_imports = "" + if has_distributed_ops: + distributed_imports = textwrap.dedent( + """ +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore + """ + ).strip() + + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +{distributed_imports} +{triton_imports} + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + +{maybe_fbcode_instructions()} + """ + ) + model_str += textwrap.dedent( + """ +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +""" + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + + try: + if isinstance(kernel, Autotuner): + # pyrefly: ignore [missing-attribute] + if isinstance(kernel.fn, Heuristics): + model_str += "ERROR: Repro will not work as intended, " + model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n" + break + + config_strs = [] + # pyrefly: ignore [missing-attribute] + for kernel_config in kernel.configs: + # pyrefly: ignore [bad-argument-type] + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + # pyrefly: ignore [missing-attribute] + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + # pyrefly: ignore [missing-attribute] + kernel._fn_name + if isinstance(kernel, JITFunction) + # pyrefly: ignore # missing-attribute + else kernel.fn._fn_name + ) + fn_name = fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + except AttributeError as e: + model_str += "ERROR: Repro will not work as intended, " + model_str += f"User defined triton kernel exception: {e}\n" + + # pyrefly: ignore [unbound-name] + if len(kernel_side_table.constant_args) > 0: + # pyrefly: ignore [unbound-name] + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + + model_str += NNModuleToString.convert(gm) + + writer = InputWriter(save_dir, stable_hash=stable_hash) + used_syms = {} + + # Extract from graph placeholders and their corresponding arguments + placeholder_targets = fx_placeholder_targets(gm) + for placeholder, arg in zip(placeholder_targets, args): + # pyrefly: ignore [unbound-name] + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + # pyrefly: ignore [unbound-name] + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + elif arg is None: + writer.const(placeholder) + else: + writer.unsupported(placeholder, arg) + + # Extract symbolic variables from the same arguments + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(arg, torch.SymInt) + # By checking sympy.Symbol, we are excluding any symbolic expressions. + # TODO: we may need to solve expressions to extract symbol definitions. + and isinstance(arg.node.expr, sympy.Symbol) + and arg.node.hint is not None + ): + used_syms[str(arg.node)] = arg.node.hint + # pyrefly: ignore [unbound-name] + elif isinstance(arg, torch.Tensor): + # Extract symbolic variables from tensor shapes and strides + for dim in arg.shape: + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(dim, torch.SymInt) + and isinstance(dim.node.expr, sympy.Symbol) + and dim.node.hint is not None + ): + used_syms[str(dim.node)] = dim.node.hint + for stride in arg.stride(): + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(stride, torch.SymInt) + and isinstance(stride.node.expr, sympy.Symbol) + and stride.node.hint is not None + ): + used_syms[str(stride.node)] = stride.node.hint + # Add symbolic variable definitions to the top of the generated code + if used_syms: + hint_lines = "\n".join( + f"{name} = {hint}" for name, hint in sorted(used_syms.items()) + ) + model_str = f"{hint_lines}\n\n{model_str}" + + load_args_lines = writer.lines() + load_args_code = "\n".join(load_args_lines) + model_str += load_args_code + "\n" + + model_str += "mod = Repro()\n" + return model_str + + +def save_graph_repro( + fd: IO[Any], + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + stable_hash: bool = False, +) -> None: + if any( + isinstance(arg, torch.fx.experimental._backward_state.BackwardState) + for arg in args + ): + fd.write( + "Repro is not generated due to existence of BackwardState in graph input" + ) + return + + if save_dir is not None: + save_dir = normalize_path_separator(save_dir) + + # Check if the graph contains distributed operations + has_distributed_ops = any( + node.op == "call_function" + and isinstance(node.target, OpOverload) + and node.target.namespace in {"_c10d_functional", "c10d_functional"} + for node in gm.graph.nodes + ) + + fd.write( + generate_compiler_repro_string( + gm, + args, + stable_output=stable_output, + save_dir=save_dir, + stable_hash=stable_hash, + has_distributed_ops=has_distributed_ops, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + if tracing_mode is None: + tracing_mode = "real" + if any( + has_free_symbols(a) for a in args if not isinstance(a, FakeScriptObject) + ): + tracing_mode = "symbolic" + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") + + # Add distributed initialization before run_repro if needed + if has_distributed_ops: + fd.write( + " # Initialize FakeProcessGroup for distributed operations\n" + " store = FakeStore()\n" + " dist.init_process_group(\n" + ' backend="fake",\n' + " rank=0,\n" + " world_size=2,\n" + " store=store\n" + " )\n" + ) + + fd.write( + f" with torch.no_grad():\n" + f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # To run it separately, do \n" + f" # mod, args = run_repro(mod, load_args, accuracy={accuracy!r}, command='get_args', " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # mod(*args)" + ) + + # Add distributed cleanup after run_repro + if has_distributed_ops: + fd.write("\n dist.destroy_process_group()\n") + + +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + accuracy: Optional[Union[str, bool]] = None, +) -> None: + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro( + fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP MINIFIER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str +) -> None: + out = io.StringIO() + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify") + return helper_for_dump_minify(out.getvalue()) + + +def isolate_fails( + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + env: Optional[dict[str, Any]] = None, + save_dir: Optional[str] = None, + accuracy: Optional[Union[bool, str]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, +) -> bool: + if env is None: + env = {} + subdir = os.path.join(os.getcwd(), "isolate") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") + with open(file_name, "w") as fd: + save_graph_repro( + fd, + fx_g, + args, + compiler_name, + save_dir=save_dir, + command="minifier-query", + accuracy=accuracy, + tracing_mode=tracing_mode, + check_str=check_str, + ) + # with open(file_name, "r") as fd: + # print(fd.read()) + new_env = os.environ.copy() + new_env = {**new_env, **env} + if use_buck: + cmd = BuckTargetWriter(file_name).write(print_msg=False) + else: + cmd = [sys.executable, file_name] + with ( + TemporaryFile() as stdout, + TemporaryFile() as stderr, + subprocess.Popen( + cmd, + cwd=subdir, + stdout=stdout, + stderr=stderr, + env=new_env, + ) as p, + ): + p.wait() + + stdout.seek(0) + stderr.seek(0) + print( + textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), + file=sys.stdout, + ) + print( + textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), + file=sys.stderr, + ) + # print(f"Isolated test failed - {file_name}") + return p.returncode != 0 + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER TOOLS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def inductor_fails( + fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None +) -> bool: + has_cuda = False + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + has_cuda = True + break + + def sync() -> None: + if has_cuda: + # Ensures that segfaults are surfaced + torch.cuda.synchronize() + + from torch._inductor.compile_fx import compile_fx_inner + + try: + result = fx_g(*args) + assert isinstance(result, (tuple, list)) + assert not any(isinstance(x, (tuple, list)) for x in result) + except Exception: + return False + + sync() + + try: + compile_mod = compile_fx_inner(fx_g, args) + assert not isinstance(compile_mod, str) + compile_mod(args) + sync() + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + print(repr(e)) + return True + return False + + +def inductor_accuracy_fails( + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + check_str: Optional[str] = None, + *, + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: + from torch._inductor.compile_fx import compile_fx_inner + + return backend_aot_accuracy_fails( + fx_g, + args, # type: ignore[arg-type] + compile_fx_inner, # type: ignore[arg-type] + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + + +backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def repro_common( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, Sequence[Any]]: + # Invariant for graphs we generate with the repro script + assert not any(mod.named_parameters()) + for n, b in mod.named_buffers(): + if b.numel() > MAX_CONSTANT_NUMEL_INLINE: + log.warning( + "Constant %s was not serialized, generated random data instead. " + "If you think this is affecting you, please comment on " + "https://github.com/pytorch/pytorch/issues/100468", + n, + ) + + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + # Turn mod into a GraphModule the slow way + # TODO: speed this up + mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args) + + # pyrefly: ignore [bad-assignment] + torch._inductor.config.generate_intermediate_hooks = True + + return mod, args + + +ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { + "": inductor_fails, + # This might look inverted but it's not. strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + "accuracy": functools.partial( + inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True + ), + "strict_accuracy": inductor_accuracy_fails, +} + + +def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: + mod, args = repro_common(options, mod, load_args) + fail_fn = functools.partial( + ACCURACY_FAILS[options.accuracy], + check_str=options.check_str, # type: ignore[call-arg] + ) + if fail_fn(mod, args): + sys.exit(1) + else: + sys.exit(0) + + +def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: + from functorch.compile import minifier + + mod, args = repro_common(options, mod, load_args) + compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor" + + favored_device = 1 if torch.cuda.device_count() >= 2 else 0 + env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)} + + module_fails: Any + if options.isolate: + module_fails = functools.partial( + isolate_fails, + env=env_variables, + compiler_name=compiler_name, + save_dir=options.save_dir, + accuracy=options.accuracy, + tracing_mode=options.tracing_mode, + ) + else: + module_fails = ACCURACY_FAILS[options.accuracy] + + minifier( + mod, + args, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, compiler_name=compiler_name + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.hooks import intermediate_hook + + mod, args = repro_common(options, mod, load_args) + + # TODO: The logic for cloning inputs/models here is intentionally + # modeled off of run_fwd_maybe_bwd, but arguably it is better not to + # clone inputs (as you are doubling your effective GPU memory usage). + # It is certainly faster though! It probably makes sense to let the + # user specify the offload strategy. + + with tqdm(desc="Compiling"): + compiled = compile_fx_inner(mod, args) + total = counters["inductor"]["intermediate_hooks"] + + known_names = set() + + def save_hook(name: str, val: Any) -> None: + known_names.add(name) + if not options.skip_saving_inductor_intermediates: + writer.write_tensor(os.path.join("inductor", name), val) + pbar.update(1) # type: ignore[has-type] + + writer = torch.utils._content_store.ContentStoreWriter( + options.save_dir, stable_hash=options.stable_hash + ) + reader = torch.utils._content_store.ContentStoreReader(options.save_dir) + + new_args = clone_inputs(args) + with ( + intermediate_hook(save_hook), + tqdm(desc="Saving inductor intermediates", total=total) as pbar, + ): + assert not isinstance(compiled, str) + compiled(new_args) # type: ignore[arg-type] + assert not new_args + + def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: + diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] + diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] + + if not diff_values: + return None + else: + return " and ".join(f"{a} != {b}" for a, b in diff_values) + + def check_hook(name: str, val: Any) -> None: + meta = writer.compute_tensor_metadata(val) + meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})") + pbar.update(1) + + if not options.skip_check_deterministic: + new_args = clone_inputs(args) + with ( + intermediate_hook(check_hook), + tqdm(desc="Checking inductor determinism", total=total) as pbar, + ): + compiled(new_args) # type: ignore[arg-type] + assert not new_args + + class WriterInterp(fx.Interpreter): + def __init__(self, mod: torch.nn.Module, subdir: str) -> None: + super().__init__(mod) + self.subdir = subdir + + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + pbar.update(1) + writer.write_tensor(os.path.join(self.subdir, name), r) + return r + + # NB: the module cast doesn't actually do anything, since there are no + # parameters/buffers on the module + if not options.skip_saving_float64_intermediates: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] + with tqdm(desc="Saving float64 intermediates", total=total) as pbar: + WriterInterp(new_mod, "float64").boxed_run(new_args) + assert not new_args + + class ExactReaderInterp(fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + meta = writer.compute_tensor_metadata(r) + meta2 = reader.read_tensor_metadata(os.path.join("float64", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})") + pbar.update(1) + return r + + # TODO: check eager determinism + + if not options.skip_check_deterministic: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] + with tqdm(desc="Checking float64 determinism", total=total) as pbar: + ExactReaderInterp(new_mod).boxed_run(new_args) + assert not new_args + + # Now that we've saved everything, interp through the eager graph + # and do comparisons + class ReaderInterp(fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + inductor = reader.read_tensor(os.path.join("inductor", name)) + float64 = reader.read_tensor(os.path.join("float64", name)) + logged = False + + def log_error(msg: str, *args: Any) -> None: + nonlocal logged + logged = True + pbar.write(f"DIVERGED at {name}: {msg % args}") + + if not same( + r, + inductor, + float64, + tol=torch._dynamo.config.repro_tolerance, + equal_nan=True, + log_error=log_error, + ): + assert logged + pbar.update(1) + return r + + with tqdm(desc="Checking divergence", total=total) as pbar: + ReaderInterp(mod).boxed_run(args) + assert not args + + +def repro_get_args( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, list[Any]]: + mod, args = repro_common(options, mod, load_args) + return mod, args # type: ignore[return-value] + + +def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: + from torch._inductor.compile_fx import compile_fx_inner + + mod, args = repro_common(options, mod, load_args) + + from torch.cuda import synchronize + + compiled = compile_fx_inner(mod, args) + assert not isinstance(compiled, str) + + if options.accuracy != "": + # We don't really respect --accuracy vs --strict-accuracy here, it + # seems counterintuitive + if not same_two_models( + mod, + compiled, # type: ignore[arg-type] + args, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Bad accuracy detected") + else: + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + compiled(list(args)) + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +# TODO: lazily load the inputs or something, rather than cloning them +def run_repro( + mod: nn.Module, + load_args: Any, + *, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + patch_code: Optional[str] = None, + check_str: Optional[str] = None, + **kwargs: Any, +) -> Any: + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + if patch_code is not None: + log.warning( + "patch_code no longer works on this version of PyTorch, silently ignoring" + ) + + parser = argparse.ArgumentParser( + description=f"""\ +An after_aot repro script, typically triggering a bug in PyTorch Inductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--tracing-mode", + type=str, + metavar="{real,fake,symbolic}", + default=tracing_mode, + help="how to trace the repro module into a GraphModule with metadata", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify,analyze}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify_isolate = parser_minify.add_mutually_exclusive_group() + parser_minify_isolate.add_argument( + "--isolate", + action="store_true", + default=True, + help="run in separate processes to avoid interference (default)", + ) + parser_minify_isolate.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + help="speed up by running all compilation in same process", + ) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + # TODO: make this an option for --analyze too + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + parser_analyze = subparsers.add_parser( + "analyze", help="run the accuracy analyzer on the repro" + ) + common_flags(parser_analyze) + parser_analyze.add_argument( + "--skip-saving-inductor-intermediates", + action="store_true", + help="skip saving inductor intermediates on --analyze", + ) + parser_analyze.add_argument( + "--skip-saving-float64-intermediates", + action="store_true", + help="skip saving float64 intermediates", + ) + parser_analyze.add_argument( + "--skip-check-deterministic", + action="store_true", + help="skip checking that the network is deterministic", + ) + parser_analyze.add_argument( + "--stable-hash", + action="store_true", + help="use SHA-1 checksum instead of fast (but possibly unsound) hash", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "analyze": repro_analyze, + "minifier-query": repro_minifier_query, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command](options, mod, load_args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py new file mode 100644 index 0000000000000000000000000000000000000000..a17518fc6c74d7c64477964f3fc7d1176fc67019 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py @@ -0,0 +1,637 @@ +""" +Utilities for reproducing and debugging issues in Dynamo after graph capture. + +This file provides tools and infrastructure for debugging problems that occur +after Dynamo has captured the graph but before/during backend compilation. +Key components include: + +- Minification tools to reduce large graphs to minimal failing examples +- Accuracy testing to validate compiled graph outputs match eager mode +- Repro generation to create standalone reproduction scripts +- Debug backends for capturing and analyzing failures +- Utilities for saving/loading graph states and inputs + +The tools here focus specifically on the post-graph-capture stage, making them +useful for debugging backend compilation issues, AOTAutograd problems, and +accuracy discrepancies between compiled and eager execution. +""" + +import argparse +import copy +import functools +import logging +import os +import shutil +import sys +import textwrap +from collections.abc import Callable, Sequence +from importlib import import_module +from typing import Any, Optional, Union + +import torch +import torch.fx as fx +from torch._dynamo.debug_utils import ( + AccuracyError, + backend_accuracy_fails, + BUCK_CMD_PREFIX, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + minifier_dir, + NNModuleToString, + NopInputReader, + run_fwd_maybe_bwd, + same_two_models, +) +from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets +from torch.hub import tqdm + +from .. import config +from ..backends.registry import CompilerFn, lookup_backend, register_debug_backend +from ..debug_utils import clone_inputs_retaining_gradness + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _accuracy_fails( + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], +) -> bool: + return backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + +class WrapBackendDebug: + def __init__( + self, unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] + ) -> None: + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_backend = unconfigured_compiler_fn + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name # type: ignore[attr-defined] + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[Any], **kwargs: Any + ) -> torch.fx.GraphModule: + compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) + assert config.repro_after in ("dynamo", "aot", None) + + if config.repro_after == "dynamo": + + def add_paths(exc: Exception) -> None: + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") # type: ignore[attr-defined] + if use_buck: + exc.buck_command = " ".join( # type: ignore[attr-defined] + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] # type: ignore[attr-defined] + ) + + if config.repro_level == 3: + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) + + # Check for either accuracy (level 4) or other type of failures. + if config.repro_level == 4: + # Check Accuracy + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] + log.warning( + "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." + ) + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + exc = AccuracyError("Bad accuracy detected.") + add_paths(exc) + raise exc + else: + try: + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] + except Exception as exc: + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + if config.repro_level == 1: + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=self._compiler_name + ) + dump_state_fn( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs + ) + elif config.repro_level == 2: + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + add_paths(exc) + raise + else: + compiled_gm = compiler_fn(gm, example_inputs) + + return compiled_gm # type: ignore[return-value] + + +def wrap_backend_debug( + unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] +) -> WrapBackendDebug: + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO DUMPERS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_dynamo_fx_repro_string( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", +) -> str: + """ + Generate a repro string for backend-agnostic minified version. + """ + + model_str = NNModuleToString.convert(gm) + + # TODO: Figure out why torch.compile'd hash isn't work on this codepath + writer = InputWriter(save_dir, stable_hash=True) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + else: + raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") + load_args = "\n".join(writer.lines()) + + return textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +from math import inf +import torch +from torch import tensor, device +import torch.fx as fx +import torch._dynamo +from torch._dynamo.testing import rand_strided +from torch._dynamo.debug_utils import run_fwd_maybe_bwd + +{generate_config_string(stable_output=stable_output)} + +{extra_imports} + +{model_str} +mod = Repro() + +{load_args} + +if __name__ == '__main__': + from torch._dynamo.repro.after_dynamo import run_repro + run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r}, + save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r}) +""" + ) + + +def dump_backend_repro_as_file( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: + """ + Saves the repro to a repro.py file + """ + curdir = os.getcwd() + subdir = os.path.join(os.getcwd(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + + with open(file_name, "w") as fd: + fd.write( + generate_dynamo_fx_repro_string( + gm, args, compiler_name, check_accuracy, save_dir=subdir + ) + ) + latest_repro = os.path.join(curdir, "repro.py") + log.warning("Copying %s to %s for convenience", file_name, latest_repro) + + if use_buck: + BuckTargetWriter(latest_repro).write() + + shutil.copyfile(file_name, latest_repro) + + +def dump_backend_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: + """ + Dumps the dynamo graph to repro the issue. + 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a + repro.py file. + 2) If we can't convert Fx GraphModule to a string, we use to_folder to save + the module and save a tar file. + """ + assert NNModuleToString.can_convert_to_string(gm) + return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy) + # return dump_backend_repro_as_tarfile(gm, args, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER DUMPER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify_after_dynamo( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str] +) -> None: + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + helper_for_dump_minify( + generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=config.repro_level == 4, + save_dir=subdir, + command="minify", + ) + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER BACKENDS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] + + # TODO: It's inconsistent to pass SymInt inputs but REAL tensors. + # We should pass ints and look at the GraphModule placeholders + # to resolve them to SymInt (if necessary) + example_inputs = [ + i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs + ] + + try: + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] + raise ValueError("No issue was detected") + except Exception as exc: + orig_failure = str(exc) + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + fails_fn = functools.partial( + backend_fails, + compiler_fn=compiler_fn, + orig_failure=orig_failure, + ) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + return gm + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_accuracy_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] + + # Set the eval mode to remove randomness. + gm.eval() + + # Check Accuracy + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] + log.warning("Accuracy failed for the TorchDynamo produced graph") + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name, check_accuracy=True + ) + fails_fn = functools.partial( + _accuracy_fails, + compiler_fn=compiler_fn, # type: ignore[arg-type] + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + else: + log.error("Input graph does not fail accuracy testing") + return gm + + +def backend_fails( + gm: fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: CompilerFn, + orig_failure: Sequence[Any], +) -> bool: + """ + Minifier uses this function to identify if the minified graph module fails + with the same error. + + One caveat is that minifier can potentially go into a wrong direction when + the resulting graph module fails for a different reason. To avoid this, we + save the string for the original exception and check similarity between new + and old exception. They can be somewhat different in some cases, when the + exception string depends on the failing node information. So, we have a + loose similarity metric to guide the minifier path. + """ + from difflib import SequenceMatcher + + try: + # Run the original gm to check eager validity + run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) # type: ignore[arg-type] + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) # type: ignore[arg-type] + except Exception as e: + new_failure = str(e) + if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: + return True + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[Any]: + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return args + + +def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: + args = run_load_args(options, mod, load_args) + + # Setup debug minifier compiler + if not options.accuracy: + compiler_fn = lookup_backend("dynamo_minifier_backend") + else: + compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend") + + if options.backend is None: + raise RuntimeError( + "Compiler name is None - this likely means that a custom compiler " + "was called by torchdynamo. Please remove this error, import your " + "custom compiler function, and replace the backend=None " + "line in run_repro to backend=" + ) + + dynamo_minifier_backend = functools.partial( + compiler_fn, + compiler_name=options.backend, # type: ignore[call-arg] + ) + opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod) + + with torch.amp.autocast("cuda", enabled=options.autocast): + opt_mod(*args) + + +def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: + opt_mod = torch._dynamo.optimize(options.backend)(mod) + + if options.accuracy != "": + mod.eval() + opt_mod.eval() # type: ignore[union-attr] + + with torch.amp.autocast("cuda", enabled=options.autocast): + # TODO: disable clone + args = run_load_args(options, mod, load_args) + assert same_two_models(mod, mod, args), "Eager itself failed" # type: ignore[arg-type] + if not same_two_models( + mod, # type: ignore[arg-type] + opt_mod, # type: ignore[arg-type] + args, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Dynamo failed") + else: + with torch.amp.autocast("cuda", enabled=options.autocast): + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) # type: ignore[arg-type] + del args + + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd( + opt_mod, # type: ignore[arg-type] + args, + only_fwd=options.only_fwd, + disable_clone=True, # type: ignore[arg-type] + ) + + +def run_repro( + mod: torch.nn.Module, + load_args: Any, + *, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + autocast: bool = False, + backend: str = "inductor", + **kwargs: Any, +) -> None: + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An after_dynamo repro script, typically triggering a bug in Dynamo or +AOTAutograd. When run with no arguments, this script defaults to running +'{command}'. Extra flags may be available; to find out more, try '{command} +--help'. There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {save_dir=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="test accuracy", + ) + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + default=False, + help="no isolate (doesn't do anything for after_dynamo)", + ) + parser.add_argument( + "--autocast", + default=autocast, + action="store_true", + help="use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--no-autocast", + dest="autocast", + action="store_false", + help="don't use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--backend", + type=str, + default=backend, + metavar="BACKEND", + help="torch.compile backend to use", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + parser_run.add_argument( + "--only-fwd", + action="store_true", + help="don't run backwards compilation for testing", + ) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + } + COMMAND_FNS[options.command](options, mod, load_args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f556787695c92b070166c364a3fbf85e262631 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py @@ -0,0 +1,661 @@ +""" +Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. + +This file provides tools and utilities for: +- Generating minimal reproducible test cases (minification) +- Handling exported programs and graph modules +- Creating debug repros for AOTI compilation issues +- Supporting both accuracy testing and error reproduction +- Managing configuration and environment for repro cases + +The main components include: +- Minification tools to reduce test cases while preserving errors +- Repro generation utilities for exported programs +- Error handling specific to AOTI compilation +- Command-line interface for running and managing repros +""" + +import argparse +import functools +import io +import logging +import os +import re +import shutil +import sys +import textwrap +from collections.abc import Sequence +from importlib import import_module +from typing import Any, IO, Optional, Union + +import torch +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + minifier_dir, + NNModuleToString, + NopInputReader, +) +from torch.export import ExportedProgram +from torch.hub import tqdm + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + + +class AOTIMinifierError(Exception): + def __init__(self, original_exception: Union[str, Exception]) -> None: + additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" + full_message = f"{additional_message}: {str(original_exception)}" + super().__init__(full_message) + self.original_exception = original_exception + + +def dump_to_minify( + exported_program: ExportedProgram, + compiler_name: str, + command: str = "minify", + options: Optional[dict[str, Any]] = None, +) -> None: + """ + If command is "minify": + Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. + If command is "run": + Dump exported_program to `cwd/repro.py`, with run command. + """ + assert command in ["minify", "run"] + + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + if command == "minify": + out = io.StringIO() + save_graph_repro_ep( + out, + compiler_name, + exported_program=exported_program, + save_dir=subdir, + command="minify", + config_patches=options, + ) + return helper_for_dump_minify(out.getvalue()) + else: + curdir = os.getcwd() + file_name = os.path.join(curdir, "repro.py") + try: + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + exported_program=exported_program, + config_patches=options, + save_dir=subdir, + command="run", + module_in_comment=True, + ) + log.warning("Writing repro file to %s", file_name) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", file_name) + + +def get_module_string(gm: torch.fx.GraphModule) -> str: + def _convert_to_comment(s_: str) -> str: + s = s_.split("\n") + if len(s) == 1: + return "# " + s_ + first = s.pop(0) + for i in range(len(s)): + line = s[i] + if line.strip() != "": + s[i] = "# " + line + else: + s[i] = "" + s = "\n".join(s) + s = first + "\n" + s + return s + + module_string = NNModuleToString.convert(gm) + return _convert_to_comment(module_string) + + +def save_graph_repro_ep( + fd: IO[Any], + compiler_name: str, + *, + exported_program: Optional[ExportedProgram] = None, + gm: Optional[torch.nn.Module] = None, + args: Optional[tuple[Any]] = None, + config_patches: Optional[dict[str, str]] = None, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + check_str: Optional[str] = None, + module_in_comment: bool = False, + strict: bool = False, +) -> None: + # Save graph for reproducing the error. + # Either exported_program or gm will be saved, depending on which one is defined. + # Only one of exported_program and gm should be defined. + + if exported_program is None and gm is None: + raise AOTIMinifierError("One of exported_program and gm must be defined") + if exported_program is not None and gm is not None: + raise AOTIMinifierError("Only one of exported_program and gm can be defined") + if gm is not None and args is None: + raise AOTIMinifierError("If gm is defined, args should also be defined") + + if exported_program is None: + assert gm is not None + assert args is not None + exported_program = torch.export.export(gm, args, strict=strict) + elif gm is None: + gm = exported_program.module(check_guards=False) + + # save a graph preview using gm + module_string = get_module_string(gm) # type: ignore[arg-type] + fd.write(module_string) + + # save a graph repro using exported_program + fd.write( + generate_compiler_repro_exported_program( + exported_program, + options=config_patches, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.aoti import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, check_str={check_str!r})\n" + ) + + +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + config_patches: Optional[dict[str, str]] = None, + accuracy: Optional[Union[str, bool]] = None, + strict: bool = False, +) -> None: + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + gm=gm, + args=tuple(args), + config_patches=config_patches, + save_dir=subdir, + accuracy=accuracy, + module_in_comment=True, + strict=strict, + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_compiler_repro_exported_program( + exported_program: ExportedProgram, + *, + options: Optional[dict[str, str]] = None, + stable_output: bool = False, + save_dir: Optional[str] = None, +) -> str: + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + if save_dir: + ep_path = os.path.join(save_dir, "exported_program.pt2") + else: + ep_path = "exported_program.pt2" + torch.export.save(exported_program, ep_path) + + model_str += f"exported_program = torch.export.load('{ep_path}')\n" + model_str += "# print(exported_program.graph)\n" + model_str += f"config_patches={options}\n" + return model_str + + +def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return tuple(args) + + +def repro_common( + options: Any, exported_program: ExportedProgram +) -> tuple[torch.fx.GraphModule, Any, Any]: + # pyrefly: ignore [bad-assignment] + torch._inductor.config.generate_intermediate_hooks = True + mod = exported_program.module(check_guards=False) + args, kwargs = exported_program.example_inputs + return mod, args, kwargs # type: ignore[return-value] + + +def repro_get_args( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> tuple[torch.fx.GraphModule, Any, Any]: + mod, args, kwargs = repro_common(options, exported_program) + return mod, args, kwargs + + +def repro_run( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: + from torch._inductor import _aoti_compile_and_package_inner + + gm, args, kwargs = repro_common(options, exported_program) + + from torch.cuda import synchronize + + _aoti_compile_and_package_inner( + gm, + args, + kwargs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=config_patches, + ) + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +def export_for_aoti_minifier( + gm: torch.nn.Module, + tuple_inputs: tuple[Any], + strict: bool = False, + skip_export_error: bool = True, +) -> Optional[torch.nn.Module]: + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + # In these case, we return None. Otherwise, we return the exported graph module. + # This won't affect the minifier result because the minifier is only responsible for catching + # errors in AOTI, not export. + # + # Please add to this list of illegal graphs if you change the implementation here. + # - graph output is not allowed by export + # + # If skip_export_error=True, then the errors in export will not be raised, and the minifier + # will keep exploring and ignore this graph. + from torch._dynamo.exc import UserError, UserErrorType + + try: + ep = torch.export.export(gm, tuple_inputs, strict=strict) + gm = ep.module(check_guards=False) + return gm + except Exception as e: + if skip_export_error: + return None + if isinstance(e, UserError) and e.error_type == UserErrorType.INVALID_OUTPUT: + # graph output is not allowed by export when strict=True + return None + if isinstance(e, RuntimeError): + # graph output is not allowed by export when strict=False + pattern = r"Found .* in output, which is not a known type\." + if re.search(pattern, str(e)) is not None: + return None + raise AOTIMinifierError(e) from e + # we should never reach here + return None + + +def repro_minify( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: + from functorch.compile import minifier + from torch._inductor import _aoti_compile_and_package_inner + from torch._inductor.compile_fx import _aoti_flatten_inputs + + mod, args, kwargs = repro_common(options, exported_program) + + # update serialized_in_spec and serialized_out_spec + flat_example_inputs, inductor_configs = _aoti_flatten_inputs( + mod, args, kwargs, options=config_patches + ) + compiler_name = "aot_inductor" + assert options.minifier_export_mode in ["dynamo", "python"] + strict = options.minifier_export_mode == "dynamo" + skip_export_error = options.skip_export_error + + from torch.cuda import synchronize + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + def module_fails( + gm: torch.fx.GraphModule, + flat_example_inputs: list[Any], + check_str: Optional[str] = None, + ) -> bool: + # Need to export first so the in_spec and out_spec are populated + tuple_inputs = tuple(flat_example_inputs) + # pyrefly: ignore [bad-assignment] + gm = export_for_aoti_minifier( + gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error + ) + + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + if gm is None: + return False + + assert isinstance(gm, torch.fx.GraphModule) + + try: + _aoti_compile_and_package_inner( + gm, + tuple_inputs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=inductor_configs, + ) + if need_sync: + synchronize() # ensure segfaults are surfaced + return False + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + return True + + minifier( + mod, + flat_example_inputs, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, + compiler_name=compiler_name, + config_patches=config_patches, + accuracy=options.accuracy, + strict=strict, + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def run_repro( + exported_program: ExportedProgram, + *, + config_patches: Optional[dict[str, str]] = None, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + minifier_export_mode: str = "python", + skip_export_error: bool = True, + **more_kwargs: Any, +) -> Any: + for k in more_kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An AOTI repro script, typically triggering a bug in PyTorch AOTInductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + parser_minify.add_argument( + "--minifier-export-mode", + type=str, + default=minifier_export_mode, + help=( + "The export mode used in minifier, either dynamo or python." + "`dynamo` corresponds to strict=True, and `python` corresponds to strict=False." + ), + ) + parser_minify.add_argument( + "--skip-export-error", + type=bool, + default=skip_export_error, + help="Skip intermediate graphs that cannot be exported.", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command]( + options, exported_program, config_patches=config_patches + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac31eeee5362e1d1becbdeb6199ec70cea5c0e2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py @@ -0,0 +1,230 @@ +""" +This package implements variable tracking and symbolic execution capabilities for Dynamo, +which are essential for converting Python code into FX graphs. It provides a comprehensive +set of variable types that handle different Python constructs during tracing. + +Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible +for tracking and symbolically executing operations on specific Python objects. This enables +Dynamo to: +- Track the flow of values through Python code +- Maintain correct semantics during graph conversion +- Handle complex Python features like context managers, iterators, and custom objects +- Support both eager and symbolic execution modes + +The VariableTracker base class provides the foundation for all variable types, with each +subclass implementing specific behavior for different Python constructs. This modular design +allows Dynamo to accurately trace and optimize Python code while preserving its semantics. +""" + +from .base import VariableTracker +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + CatchWarningsCtxManagerVariable, + ContextWrappingVariable, + CUDADeviceVariable, + DeterministicAlgorithmsVariable, + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + DynamoConfigPatchVariable, + ErrorOnGraphBreakVariable, + FSDPParamGroupUseTrainingStateVariable, + FxTracebackAnnotateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + TemporarilyPopInterpreterStackCtxManagerVariable, + VmapIncrementNestingCtxManagerVariable, + WithEnterFunctionVariable, + WithExitFunctionVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + NNModuleHooksDictVariable, + SetVariable, +) +from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctionDecoratedByContextlibContextManagerVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + NestedUserFunctionVariable, + PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, + SkipFunctionVariable, + TMADescriptorExperimentalVariable, + TMADescriptorStableVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, + WrapperUserMethodVariable, +) +from .higher_order_ops import ( + FunctionalCallVariable, + FunctorchHigherOrderVariable, + ReparametrizeModuleCallVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ( + CountIteratorVariable, + FilterVariable, + IteratorVariable, + ItertoolsVariable, + MapVariable, + ObjectIteratorVariable, + RepeatIteratorVariable, + ZipVariable, +) +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradFunctionContextVariable, + AutogradFunctionVariable, + CellVariable, + DeletedVariable, + ExceptionVariable, + GetAttrVariable, + LambdaVariable, + MethodWrapperVariable, + NewGlobalVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + StringFormatVariable, + SuperVariable, + TorchVersionVariable, + TypingVariable, + UnknownVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + NNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .sdpa import SDPAParamsVariable +from .streams import EventVariable, StreamContextVariable, StreamVariable +from .tensor import ( + DataPtrVariable, + FakeItemVariable, + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, + UntypedStorageVariable, +) +from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable +from .user_defined import ( + FrozenDataClassVariable, + MutableMappingVariable, + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedExceptionObjectVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedSetVariable, + UserDefinedTupleVariable, +) + + +__all__ = [ + "AutogradFunctionContextVariable", + "AutogradFunctionVariable", + "BackwardHookVariable", + "BaseListVariable", + "BuiltinVariable", + "CatchWarningsCtxManagerVariable", + "ConstantVariable", + "ConstDictVariable", + "ContextWrappingVariable", + "CountIteratorVariable", + "CreateTMADescriptorExperimentalVariable", + "CreateTMADescriptorStableVariable", + "CUDADeviceVariable", + "DataPtrVariable", + "DefaultDictVariable", + "DeletedVariable", + "DeterministicAlgorithmsVariable", + "DictKeySetVariable", + "DynamoConfigPatchVariable", + "EnumVariable", + "FakeItemVariable", + "GetAttrVariable", + "GradModeVariable", + "IteratorVariable", + "ItertoolsVariable", + "LambdaVariable", + "LazyVariableTracker", + "ListIteratorVariable", + "ListVariable", + "NamedTupleVariable", + "NestedUserFunctionVariable", + "CellVariable", + "NewGlobalVariable", + "NNModuleVariable", + "NumpyNdarrayVariable", + "NumpyVariable", + "OptimizerVariable", + "PlacementVariable", + "PolyfilledFunctionVariable", + "PythonModuleVariable", + "RangeVariable", + "RemovableHandleVariable", + "RepeatIteratorVariable", + "SDPAParamsVariable", + "ErrorOnGraphBreakVariable", + "SkipFunctionVariable", + "SliceVariable", + "StringFormatVariable", + "SuperVariable", + "TemporarilyPopInterpreterStackCtxManagerVariable", + "TensorVariable", + "TMADescriptorExperimentalVariable", + "TMADescriptorStableVariable", + "TorchCtxManagerClassVariable", + "TorchInGraphFunctionVariable", + "TorchVersionVariable", + "TupleVariable", + "UnknownVariable", + "UnspecializedNNModuleVariable", + "UnspecializedPythonVariable", + "UntypedStorageVariable", + "UserDefinedClassVariable", + "UserDefinedTupleVariable", + "UserDefinedObjectVariable", + "UserFunctionVariable", + "UserMethodVariable", + "VariableTracker", + "WithEnterFunctionVariable", + "WithExitFunctionVariable", + "MappingProxyVariable", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91e9f6aec418ab109c9bc13425871c4248b8679d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16cc58940a90b25aac024e47ca67c8f3e5adc8de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1c7b45ced36ce2b1b9f29863c41d335018e497 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15010235cc824e5f6973ca9282a243e53ab02c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d53a059913f2fe630a30b38c8acb32e96744a65b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20a29d74a8d3a2b9e39f35fc6a04acd12723ed8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..792342ee4c315740992180936ef756e6452aec4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..983bfe400c1dc1b2b5530f3112e68a6d01a06866 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75b1de93224428c55192cc1255ad04d0456c8c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80d6fdbeb5d61fc5a2f1c8c0eadae8ea02e48581 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe019fc33c02976009bb0cfa1ce8040396ef3790 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a110bdbe36715206f375aad91c717e8f1e4d15 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01fb3fe90c741955bf5f7fc30d68cc045440ee1a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f56e175adaf1488ccfd743a66555e4ae3b7dd7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f6e2c8524151e42105619d73fecdf50f243c9e6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b7ec486f986e856a86e0751ecacfa2bd899b2aa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f4e26b9f0c561c27f5e64a388ddef55ca7e4dbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/base.py new file mode 100644 index 0000000000000000000000000000000000000000..af63c4c9d75999a677d6b1c327ea58b165b2520b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/base.py @@ -0,0 +1,825 @@ +""" +Core variable tracking functionality for Dynamo. This module defines the fundamental +classes and systems used to track and manage variables during Dynamo's operation. + +The module provides: +1. VariableTracker - The base class for tracking variables during compilation +2. MutationType system - Classes for tracking and managing mutations to variables +3. Source type management - Utilities for tracking variable origins and scope +4. Variable state management - Tools for managing variable state and transformations + +These components form the foundation of Dynamo's variable handling system, +enabling accurate tracking and transformation of Python code into optimized +computations. +""" + +import collections +import logging +from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView +from enum import Enum +from typing import Any, NoReturn, Optional, TYPE_CHECKING + +from torch._guards import Guard +from torch.fx.proxy import Node + +from .. import graph_break_hints, variables +from ..current_scope_id import current_scope_id +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, Source +from ..utils import cmp_name_to_op_mapping, istype + + +if TYPE_CHECKING: + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) + + +class SourceType(Enum): + """ + This Enum divides VariableTracker into 2 cases, depending on the variable + it represents: + - already existed that Dynamo began tracking while introspection (Existing) + - is a new variable that is created during Dynamo introspection (New) + + In general, we have these invariants: + 1. for `VariableTracker` associated with `Existing`, its `source` field must not be None. + 2. for `VariableTracker` associated with `New`, most of the time its + `source` field is None, except for cases like side effect codegen for + `AttributeMutationNew`, during which we generate a + `LocalSource('tmp...')` for such variable, to facilitate codegen. + """ + + Existing = 0 + New = 1 + + +class MutationType: + """ + Base class for Variable.mutation_type. It encodes information about + 1. The type of mutation Dynamo allows on the variable. + 2. Whether the value represented by this variable already existed before + Dynamo tracing. + """ + + def __init__(self, typ: SourceType) -> None: + # In HigherOrderOperator tracing, we need to distinguish + # between MutationTypes inside the HigherOrderOperator and + # ones outside it. For example, it is not safe to mutate + # `a` in the following example because it was constructed + # in a different scope. + # + # def f(x): + # a = 1 + # def g(x): + # nonlocal a + # a = 2 + # return x + # return wrap(g, x) + a + # + # We use self.scope to distinguish this. + # scope == 0: The object was an existing variable + # scope == 1: The object was created while Dynamo + # was introspecting a function + # (and no HigherOrderOps were involved) + # scope >= 2: The object was created through + # Dynamo introspection of a HigherOrderOp. + # The exact number corresponds to the level + # of nested HigherOrderOps. + if typ is SourceType.Existing: + self.scope = 0 + elif typ is SourceType.New: + self.scope = current_scope_id() + else: + unimplemented( + gb_type="Unsupported SourceType", + context=f"MutationType.__init__ {self} {typ}", + explanation=f"Dynamo does not support the type `{typ}`", + hints=[ + "This branch is not supposed to be reachable.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +class ValueMutationNew(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value is created by the bytecode Dynamo is tracing through. + + For instance, Dynamo could model a newly created list with this marker, + indicating that while we need to model mutations to this list, we don't have + to emit bytecode for these mutations if the list doesn't escape into the + Python world. + """ + + def __init__(self) -> None: + super().__init__(SourceType.New) + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: object) -> bool: + return self is other + + +class ValueMutationExisting(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing list with this marker, + indicating that if we encounter mutations to this list, we need to buffer + and re-apply those mutations after the graph runs, since the list might be + used afterwards in Python. + """ + + # A flag to indicate whether mutation happened on the associated + # `VariableTracker`. This enables SideEffects to accurately and quickly + # filter out which pre-existing values it needs to generate mutation for. + is_modified: bool + + def __init__(self, is_modified: bool = False) -> None: + super().__init__(SourceType.Existing) + self.is_modified = is_modified + + +class AttributeMutation(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates that Dynamo + allows mutation on the value's attributes. + """ + + +class AttributeMutationExisting(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing object with this marker, + indicating that if we encounter mutations to this object, we need to buffer + then re-apply those mutations after the graph runs, since the object might + be used afterwards in Python. + """ + + def __init__(self) -> None: + super().__init__(SourceType.Existing) + + +class AttributeMutationNew(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value is created by the bytecode Dynamo is tracing through. + + For instance, Dynamo could model a newly created object with this marker, + indicating that while we need to model mutations to this object, we don't + have to emit bytecode for these mutations if the object doesn't escape into + the Python world. + """ + + def __init__(self, cls_source: Optional[Source] = None) -> None: + super().__init__(SourceType.New) + self.cls_source = cls_source + + +def _is_top_level_scope(scope_id: int) -> bool: + return scope_id == 1 + + +def is_side_effect_safe(m: MutationType) -> bool: + scope_id = current_scope_id() + + # In the top-level scope (if no HigherOrderOperators are involved), + # we are allowed to modify variables created in this scope as well + # as existing variables. + if _is_top_level_scope(scope_id): + return True + # Otherwise, only allow local mutation of variables created in the current scope + return m.scope == scope_id + + +# This helps users of `as_python_constant` to catch unimplemented error with +# more information; it inherits `NotImplementedError` for backward +# compatibility reasons. +class AsPythonConstantNotImplementedError(NotImplementedError): + vt: "VariableTracker" + + def __init__(self, vt: "VariableTracker") -> None: + super().__init__(f"{vt} is not a constant") + self.vt = vt + + +class VariableTrackerMeta(type): + all_subclasses: list[type] = [] + + def __new__( + mcs: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> type: + # Determine which metaclass to use based on the class attributes + # Classes with _no_implicit_realize = True should NOT implicitly realize + # (they need standard isinstance behavior to avoid infinite recursion) + # Check if any base class has _no_implicit_realize set, or if it's in attrs + no_implicit_realize = attrs.get("_no_implicit_realize", False) or any( + getattr(base, "_no_implicit_realize", False) for base in bases + ) + if no_implicit_realize or name == "VariableTracker": + # Use base VariableTrackerMeta (no custom __instancecheck__) + return super().__new__(VariableTrackerMeta, name, bases, attrs) + else: + # Use ImplicitRealizingVariableTrackerMeta for all other subclasses + return super().__new__( + ImplicitRealizingVariableTrackerMeta, name, bases, attrs + ) + + def __init__( + cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: + super().__init__(name, bases, attrs) # type: ignore[misc] + VariableTrackerMeta.all_subclasses.append(cls) + + +class ImplicitRealizingVariableTrackerMeta(VariableTrackerMeta): + def __instancecheck__(self, instance: object) -> bool: + """Make isinstance work with LazyVariableTracker""" + if instancecheck(LazyVariableTracker, instance): + return instance.lazy_isinstance(self) # pyrefly: ignore[missing-attribute] + return instancecheck(self, instance) + + +class VariableTracker(metaclass=VariableTrackerMeta): + """ + Base class for tracked locals and stack values + + VariableTracker instances are immutable and should be copied in + order to change them. + + Prefer the factory function VariableTracker.build() over VariableTracker.__init__(). + """ + + # fields to leave unmodified in apply() + _nonvar_fields = { + "value", + "guards", + "source", + "mutation_type", + "parents_tracker", + "user_code_variable_name", + } + + def clone(self, **kwargs: Any) -> "VariableTracker": + """Shallow copy with some (optional) changes""" + args = dict(self.__dict__) + args.update(kwargs) + return self.__class__(**args) + + @classmethod + def visit( + cls, + fn: Callable[["VariableTracker"], None], + value: Any, + cache: Optional[dict[int, Any]] = None, + ) -> None: + """ + Walk value and call fn on all the VariableTracker instances + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = value + + if isinstance(value, VariableTracker): + value = value.unwrap() + fn(value) + value = value.unwrap() # calling fn() might have realized it + nonvars = value._nonvar_fields + for key, subvalue in value.__dict__.items(): + if key not in nonvars: + cls.visit(fn, subvalue, cache) + elif istype(value, (list, tuple)): + for subvalue in value: + cls.visit(fn, subvalue, cache) + elif istype(value, (dict, collections.OrderedDict)): + for subvalue in value.values(): + cls.visit(fn, subvalue, cache) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def debug_repr(self) -> str: + # Intended to be overridden to provide more info + try: + return repr(self.as_python_constant()) + except NotImplementedError: + return repr(self) + + def python_type(self) -> type: + """ + Abstract method to be implemented by subclasses of VariableTracker. + + This method should return the type represented by the instance of the subclass. + The purpose is to provide a standardized way to retrieve the Python type information + of the variable being tracked. + + Returns: + type: The Python type (such as int, str, list, etc.) of the variable tracked by + the subclass. If the type cannot be determined or is not relevant, + leaving it undefined or invoking super() is always sound. + + Note: + This is an abstract method and may be overridden in subclasses. + + Example: + class SetVariable(VariableTracker): + def python_type(self): + return set + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + """ + try: + return type(self.as_python_constant()) + except NotImplementedError: + raise NotImplementedError(f"{self} has no type") from None + + def python_type_name(self) -> str: + try: + return self.python_type().__name__ + except NotImplementedError: + return "" + + def as_python_constant(self) -> Any: + """For constants""" + raise AsPythonConstantNotImplementedError(self) + + def guard_as_python_constant(self) -> Any: + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + try: + return self.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="Not a Python constant", + context=f"guard_as_python_constant {self}", + explanation=f"Failed to convert {self} into a Python constant.", + hints=[], + ) + + def is_python_constant(self) -> bool: + try: + self.as_python_constant() + return True + except NotImplementedError: + return False + + def is_constant_match(self, *values: Any) -> bool: + """ + Check if this variable is a python constant matching one of the given values. + + Examples: + var.is_constant_match(None) # True if var is constant None + var.is_constant_match(True, False) # True if var is constant True or False + var.is_constant_match(NotImplemented) # True if var is constant NotImplemented + """ + return False + + def is_constant_none(self) -> bool: + """Check if this variable is a constant None value.""" + return False + + def make_guard(self, fn: Callable[..., Any]) -> Guard: + if self.source: + return self.source.make_guard(fn) + raise NotImplementedError + + # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase` + # and cascade that (large blast radius) + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + """getattr(self, name) returning a python constant""" + raise NotImplementedError + + def is_symnode_like(self) -> bool: + """Return True for values that can participate in SymNode operations""" + return False + + def is_tensor(self) -> bool: + """Return True for TensorVariable instances""" + return False + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + """getattr(self, name) returning a new variable""" + value = self.const_getattr(tx, name) + if not variables.ConstantVariable.is_literal(value): + raise NotImplementedError + source = self.source and AttrSource(self.source, name) + if source and not self.is_python_constant(): + # The second condition is to avoid guards on const getattr objects + # like __code__.co_argcount + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def is_proxy(self) -> bool: + try: + self.as_proxy() + return True + except NotImplementedError: + return False + + def as_proxy(self) -> Any: + raise NotImplementedError(str(self)) + + def maybe_fx_node(self) -> Optional[Node]: + try: + proxy = self.as_proxy() + import torch.fx + + if isinstance(proxy, torch.fx.Proxy): + return proxy.node + return None + except NotImplementedError: + return None + + def reconstruct(self, codegen: "PyCodegen") -> None: + raise NotImplementedError + + def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + raise NotImplementedError + + def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + + def has_unpack_var_sequence(self, tx: Any) -> bool: + try: + self.unpack_var_sequence(tx) + return True + except NotImplementedError: + return False + + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx: Any) -> bool: + return self.has_unpack_var_sequence(tx) + + # Forces unpacking the var sequence while also applying a function to each element. + # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). + # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! + def force_apply_to_var_sequence( + self, tx: Any, fn: Callable[["VariableTracker"], Any] + ) -> None: + assert self.has_force_unpack_var_sequence(tx) + for v in self.unpack_var_sequence(tx): + fn(v) + + def call_obj_hasattr(self, tx: Any, name: str) -> "ConstantVariable": + unimplemented( + gb_type="Unsupported hasattr call", + context=f"call_obj_hasattr {self} {name}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_function( + self, + tx: Any, + args: Sequence["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + unimplemented( + gb_type="Unsupported function call", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch.", + ], + ) + + def call_method( + self, + tx: Any, + name: str, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__len__" and self.has_unpack_var_sequence(tx): + assert not (args or kwargs) + return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx))) + elif ( + name == "__getattr__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + ): + return self.var_getattr(tx, args[0].as_python_constant()) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(self, type(other)) and not ( + isinstance(self, variables.GetAttrVariable) + or isinstance(other, variables.GetAttrVariable) + ): + # NB: GetAttrVariable is a special case because sometimes an + # object can map to GetAttrVariable but other time as + # SkipFunctionVariable if it is an input to the compiled + # function, e.g. tensor.data_ptr + return variables.ConstantVariable.create(NotImplemented) + # NB : Checking for mutation is necessary because we compare + # constant values + if ( + not self.is_python_constant() + or not other.is_python_constant() + or tx.output.side_effects.has_pending_mutation(self) + or tx.output.side_effects.has_pending_mutation(other) + ): + unimplemented( + gb_type="Builtin `operator.*` comparison with constant `self` failed", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Failed to compare {self} with {other}, " + + f"because {other} is not a Python constant or its mutation check fails.", + hints=[], + ) + + try: + return variables.ConstantVariable.create( + cmp_name_to_op_mapping[name]( + self.as_python_constant(), other.as_python_constant() + ) + ) + except Exception as e: + raise_observed_exception( + type(e), + tx, + args=[list(map(variables.ConstantVariable.create, e.args))], + ) + hints = [ + f"Avoid calling `{self.python_type_name()}.{name}` in your code.", + "Please report an issue to PyTorch.", + ] + # additional hint for method calls on improperly constructed iterators + if isinstance(self, variables.UserDefinedObjectVariable) and name in ( + "__iter__", + "__next__", + ): + if isinstance(self.value, (KeysView, ItemsView, ValuesView)): + hints.append( + "Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) " + "to the compiled region, instead of passing it as an input to the compiled region." + ) + hints.append( + "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) " + "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). " + "This can happen unintentionally if a previous graph break happens with a builtin iterator " + "in the local scope." + ) + hints.append( + "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo " + "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, " + "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a " + "function, or (4) use Python 3.12+." + ) + unimplemented( + gb_type="Unsupported method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + hints=hints, + ) + + def call_tree_map( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Performance optimization to implement optree.tree_map faster than tracing it""" + is_leaf_var = tree_map_kwargs.get("is_leaf") + if is_leaf_var is not None and not is_leaf_var.is_constant_none(): + pred_result = is_leaf_var.call_function(tx, [self], {}) + try: + leaf_decision = pred_result.as_python_constant() + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + if leaf_decision: + return map_fn.call_function(tx, [self, *rest], {}) + + return self.call_tree_map_branch( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def call_tree_map_branch( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)""" + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _tree_map_fallback( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + tree_map_fn_copy = tree_map_fn.clone() + tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] + log.debug( + "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)", + self, + rest, + tree_map_kwargs, + ) + return tree_map_fn_copy.call_function( + tx, + [map_fn, self, *rest], + tree_map_kwargs, + ) + + def set_name_hint(self, name: str) -> None: + pass + + def realize(self) -> "VariableTracker": + """Used by LazyVariableTracker to build the real VariableTracker""" + return self + + def unwrap(self) -> "VariableTracker": + """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" + return self + + def is_realized(self) -> bool: + """Used by LazyVariableTracker to indicate an unrealized node""" + return True + + def next_variable(self, tx: Any) -> "VariableTracker": + unimplemented( + gb_type="Unsupported next() call", + context=f"next({self})", + explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def is_strict_mode(self, tx: Any) -> bool: + return bool(tx.strict_checks_fn and tx.strict_checks_fn(self)) + + def is_mutable(self) -> bool: + """Whether Dynamo allows mutation on this variable.""" + return not self.is_immutable() + + def is_immutable(self) -> bool: + """Whether Dynamo bans mutation on this variable.""" + return self.mutation_type is None + + @staticmethod + def build( + tx: Any, + value: Any, + source: Optional[Source] = None, + ) -> Any: + """Create a new VariableTracker from a value and optional Source""" + if source is None: + return builder.SourcelessBuilder.create(tx, value) + else: + return variables.LazyVariableTracker.create(value, source) + + def is_python_hashable(self): + """ + Unlike the variable tracker's own __hash__, this method checks whether + the underlying Python object referenced by this variable tracker is hashable. + """ + try: + type_self = self.python_type() + except NotImplementedError: + type_self = type(self) + + unimplemented( + gb_type="Dynamo cannot determine whether the underlying object is hashable", + context=f"is_python_hashable {self}", + explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {type_self}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def get_python_hash(self): + """ + Unlike the variable tracker’s own __hash__, this method is used by + ConstDictVariableTracker to compute the hash of the underlying key object. + """ + unimplemented( + gb_type="Dynamo cannot determine the hash of an object", + context=f"get_python_hash {self}", + explanation=f"Dynamo does not know the hash of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def is_python_equal(self, other): + """ + NB - Deliberately not overriding the __eq__ method because that can + disable the __hash__ for the vt itself. + """ + unimplemented( + gb_type="Dynamo cannot determine the equality comparison of an object", + context=f"is_python_equal {self}", + explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def __init__( + self, + *, + source: Optional[Source] = None, + mutation_type: Optional[MutationType] = None, + ) -> None: + super().__init__() + self.source = source + self.mutation_type = mutation_type + + # NOTE sometimes mutation_type is set afterwards for implementation + # convenience, we don't validate those cases at the moment. + if mutation_type is not None: + if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)): + # If this fails, it's either + # 1. one mistakenly passed in a source + # 2. `mutation_type` is incorrect + assert source is None + else: + assert isinstance( + mutation_type, (ValueMutationExisting, AttributeMutationExisting) + ) + # If this fails, it's either + # 1. one forgot to pass in a source + # 2. `mutation_type` is incorrect + assert source is not None + + +def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn: + msg = variables.ConstantVariable.create(msg_str) + raise_observed_exception(TypeError, tx, args=[msg]) + + +def typestr(*objs: object) -> str: + if len(objs) == 1: + (obj,) = objs + if isinstance(obj, VariableTracker): + return str(obj) + else: + return type(obj).__name__ + else: + return " ".join(map(typestr, objs)) + + +instancecheck = type.__instancecheck__ +from . import builder +from .lazy import LazyVariableTracker diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..62c9bb896ef9bc4b92455f3ea71ecabdcb148be4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py @@ -0,0 +1,3957 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for building variable trackers in Dynamo. +Variable trackers are used to convert Python values into symbolic representations +that can be traced and transformed during graph capture. + +The key classes are: + +- VariableBuilder: Handles source-tracked objects that need guards and proper + reconstruction in the output graph. Used for inputs, module attributes, etc. + +- SourcelessBuilder: Handles ephemeral objects created during tracing that don't + need source tracking or guards. Used for temporary lists, intermediate values, etc. + +Variable trackers enable Dynamo to track the flow of values through the program, +maintain guards for dynamic properties, and reconstruct values in the output graph. +The builders in this module handle converting Python values into appropriate +VariableTracker instances based on their type and usage context. +""" + +import abc +import collections +import contextlib +import copy +import dataclasses +import enum +import functools +import inspect +import itertools +import logging +import math +import operator +import random +import re +import sys +import traceback +import types +import weakref +from collections.abc import Callable, MutableMapping +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch import SymInt +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.graph_bytecode_inputs import ( + get_external_object_by_index, + register_user_object, +) +from torch._dynamo.utils import ( + get_metrics_context, + is_int_specialization_case, + is_torch_sym, + set_feature_use, +) +from torch._guards import TracingContext +from torch._higher_order_ops.flat_apply import flat_apply +from torch._higher_order_ops.torchbind import call_torchbind +from torch._library.opaque_object import is_opaque_type, is_opaque_value_type +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +from torch._subclasses.meta_utils import is_sparse_any, safe_grad +from torch._utils_internal import justknobs_check +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental._dynamism import normalize_source_name +from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt +from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + _nested_int_aware_sort, + DimDynamic, + RelaxedUnspecConstraint, + StatefulSymbolicContext, + SubclassSymbolicContext, + SymbolicContext, + SymIntSymbolicContext, + TrackedFake, +) +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils.weak import TensorWeakRef + +from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules +from ..device_interface import get_registered_device_interfaces +from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard, make_dupe_guard +from ..pgo import ( + auto_dynamic, + auto_unset, + FrameStateSizeEntry, + InferStride, + process_automatic_dynamic, +) +from ..side_effects import SideEffects +from ..source import ( + AttrProxySource, + AttrSource, + CallMethodItemSource, + ChainedSource, + ConstDictKeySource, + ConvertIntSource, + DictGetItemSource, + DictSubclassGetItemSource, + DynamicScalarSource, + FloatTensorSource, + GetItemSource, + GradSource, + is_constant_source, + is_from_closure_source, + is_from_global_source, + is_from_nonlocal_source, + is_from_optimizer_source, + is_from_unspecialized_nn_module_source, + ListGetItemSource, + LocalSource, + NonSerializableSetGetItemSource, + NumpyTensorSource, + OptimizerSource, + RandomValueSource, + Source, + SubclassAttrListSource, + TupleIteratorGetItemSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + _extract_tensor_dict, + build_checkpoint_variable, + build_invoke_subgraph_variable, + clone_input, + common_constant_types, + dict_keys, + get_fake_value, + get_items_from_dict, + get_locals_to_steal, + get_static_address_type, + is_frozen_dataclass, + is_function, + is_function_or_wrapper, + is_invoke_subgraph, + is_lru_cache_wrapped_function, + is_namedtuple, + is_parameter_freezing, + is_typing, + is_utils_checkpoint, + is_wrapper_or_member_descriptor, + istype, + namedtuple_fields, + odict_values, + proxy_args_kwargs, + range_iterator, + set_example_value, + tensor_always_has_static_shape, + tuple_iterator, + tuple_iterator_getitem, + tuple_iterator_len, + unwrap_with_attr_name_if_wrapper, + wrap_fake_exception, +) +from .base import ( + AttributeMutationNew, + typestr, + ValueMutationExisting, + ValueMutationNew, + VariableTracker, + VariableTrackerMeta, +) +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + AutocastModeVariable, + DynamoConfigPatchVariable, + ErrorOnGraphBreakVariable, + NullContextVariable, + PreserveVersionContextVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + SetVariable, +) +from .distributed import ( + DeviceMeshVariable, + PlacementClassVariable, + PlacementVariable, + ProcessGroupVariable, + WorldMetaClassVariable, +) +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CollectiveFunctionRewriteVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + SysFunctionVariable, + TracebackVariable, + TritonKernelVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, +) +from .higher_order_ops import ( + LocalMapWrappedHigherOrderVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ItertoolsVariable +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SizeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradEngineVariable, + AutogradFunctionContextVariable, + AutogradFunctionVariable, + ComptimeVariable, + ConstantLikeVariable, + DebuggingVariable, + DelayGraphBreakVariable, + GetAttrVariable, + GetSetDescriptorVariable, + LambdaVariable, + LoggingLoggerVariable, + MethodWrapperVariable, + NumpyDTypeVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + SavedTensorBox, + TorchVersionVariable, + TypingVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .script_object import OpaqueObjectClassVariable, TorchScriptObjectVariable +from .sdpa import SDPAParamsVariable +from .streams import EventVariable, StreamContextVariable, StreamVariable +from .tensor import ( + NumpyNdarrayVariable, + supported_const_comparison_op_values, + SymNodeVariable, + TensorSubclassVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .torch import ( + DispatchKeySetVariable, + FuncTorchInterpreterVariable, + TorchCtxManagerClassVariable, + TorchInGraphFunctionVariable, +) +from .torch_function import ( + TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, + TorchFunctionModeVariable, +) +from .user_defined import ( + FrozenDataClassVariable, + IntWrapperVariable, + KeyedJaggedTensorVariable, + MutableMappingVariable, + SourcelessGraphModuleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedSetVariable, + UserDefinedTupleVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +DimList = list + + +def safe_has_grad(t): + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): + return hasattr(t, "grad") + + +class _missing: + pass + + +@dataclasses.dataclass +class GraphArg: + source: Source + # TODO: storing a SymInt here but not a FakeTensor is a pretty strange + # thing to do. Probably should have example (which stores an int) and + # fake_example + _example: Union[TensorWeakRef, torch.SymInt] + # When True, this indicates that this GraphArg is a Python quantity (e.g., + # a float or int) which we pass to the FX graph as a Tensor. This + # controls how we codegen calls into the Dynamo graph: we will call + # torch.as_tensor on the quantity before passing it in. + # + # Note that we typically do not pass dynamic integers as tensors, because + # they will most frequently just be used for size computation. But this + # is a policy decision that we can change our mind on; in particular, when + # an int comes from a random number generator (e.g., random.randint), we + # DO pass it as a tensor. + # + # It's also worth noting that our current tracing rules for + # pass_arg_as_tensor as subtly broken: we just pun the variable as a + # 0d scalar Tensor and pray that the semantics are the same. Which they + # often are, but not necessarily. ezyang(May 2024) plans to fix this + # soon. + pass_arg_as_tensor: bool + fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] + # UnspecializedPythonVariable often masquerades as a tensor. + # We MUST NOT generate shape guard code + # that actually tries to access tensor properties on these values. + # is_tensor lets us tell if this graph arg actually is a tensor + # or not. + is_tensor: bool = True + # Sometimes, the Tensor we pass to example is freshly allocated (smh). + # Then we cannot only keep a weak reference to it. This lets you + # stash a strong reference too. + example_strong_ref: Optional[torch.Tensor] = None + + def __setattr__(self, name, value): + # Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception. + # This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal + # GraphArg creation can be traced, and with replay_side_effects=False, + # normal STORE_ATTR bytecode only records mutations without applying them. + object.__setattr__(self, name, value) + + @property + def example(self): + if isinstance(self._example, TensorWeakRef): + r = self._example() + assert r is not None + return r + else: + return self._example + + def __post_init__(self): + if isinstance(self._example, torch.Tensor): + self._example = TensorWeakRef(self._example) + assert is_fake(self.fake_tensor) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.source) + + def erase(self): + self._example = None + self.example_strong_ref = None + + def __eq__(self, other): + return self.source.name == other.source.name + + +class BackwardStateGraphArg(GraphArg): + def __init__(self) -> None: + super().__init__( + source=None, + _example=BackwardState(), + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + + def reconstruct(self, codegen: "PyCodegen"): + assert codegen.tx.output.backward_state_var + codegen.add_push_null( + lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState") + ) + codegen.call_function(0, False) + codegen.dup_top() + codegen.store(codegen.tx.output.backward_state_var) + + +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set() + +# Capture fn pointer at import time +# This is to guard against trying to mark the iterated tensors +# as static in case user overrides fn ptr +og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers +og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters + + +class VariableBuilder: + """Wrap a python value in a VariableTracker() instance""" + + def __init__( + self, + tx, + source: Source, + ) -> None: + assert source is not None, ( + "Consider SourcelessBuilder for ephemeral objects, usually objects created locally." + ) + assert TracingContext.try_get() is not None, "Expected active TracingContext" + super().__init__() + self.tx = tx + self.source = source + self.name = source.name + + def __call__(self, value): + if value in self.tx.output.side_effects: + side_effect_result = self.tx.output.side_effects[value] + dup_guard = make_dupe_guard(self.source, side_effect_result.source) + if dup_guard: + self.install_guards(dup_guard) + + if isinstance(value, torch.nn.Module) and isinstance( + side_effect_result, UnspecializedNNModuleVariable + ): + # This means that two nn module instances with different sources + # have the same id. NN modules are somewhat special objects, + # because we have to track their nn_module_stack for ease of + # use. But if we don't do anything, we will just return the + # older variable tracker with the older nn_module_stack. So, + # lets return the old variable tracker but update its + # nn_module_stack + side_effect_result.set_nn_module_stack_source(self.source) + return side_effect_result + + cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source) + if cached_vt: + return cached_vt + + vt = self._wrap(value) + + if vt.source is None: + vt.source = self.source + + def _is_deduplicable_sym_variable(value, vt): + # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we + # should NOT track them. If we use a single SymNodeVariable instance to track them + # across multiple uses, then guards created for one usage will incorrectly apply to + # all other usages of that constant, leading to unnecessary recompilations. + return ( + is_torch_sym(value) or isinstance(value, _DynamicScalar) + ) and isinstance(vt, SymNodeVariable) + + if ( + ( + self._can_lift_attrs_to_inputs(vt) + or _is_deduplicable_sym_variable(value, vt) + ) + and value not in self.tx.output.side_effects + and not is_wrapper_or_member_descriptor(value) + ): + vt = self.tx.output.side_effects.track_object_existing(value, vt) + + self.tx.output.variable_tracker_cache.add(value, self.source, vt) + return vt + + def _can_lift_attrs_to_inputs(self, vt): + return type(vt) in { + TensorVariable, + TensorWithTFOverrideVariable, + UserDefinedObjectVariable, + NumpyNdarrayVariable, + } + + def get_source(self): + return self.source + + def install_guards(self, *guards): + source = self.get_source() + try: + tmp = [source.make_guard(guard) for guard in guards] + except NotImplementedError: + return None + install_guard(*tmp, skip=1) + return {} + + @classmethod + def _type_dispatch(cls): + return cls._type_dispatch_impl(config.trace_numpy) + + @classmethod + @functools.cache + def _type_dispatch_impl(cls, trace_numpy): + # NB: Careful not to close over self to avoid ref cycle from lru_cache + entries = [ + ( + ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + cls.wrap_tensor, + ), + ( + (tuple, list, odict_values, collections.deque, torch.Size), + cls.wrap_listlike, + ), + (tuple_iterator, cls.wrap_tuple_iterator), + (range_iterator, cls.wrap_range_iterator), + ((slice, range), cls.wrap_slice_range), + (tuple(common_constant_types), cls.wrap_literal), + (re.Pattern, cls.wrap_regex_pattern), + (weakref.ReferenceType, cls.wrap_weakref), + (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle), + (torch.jit.ScriptFunction, cls.wrap_jit_function), + (types.MappingProxyType, cls.wrap_mapping_proxy), + ] + + if trace_numpy and np: + entries.append((np.ndarray, cls.wrap_numpy_ndarray)) + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, tuple) else (ts,): + assert t not in result + result[t] = fn + + return result + + def wrap_regex_pattern(self, value: re.Pattern): + # TODO(jansel): something like a REPR_MATCH might be more robust here + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantLikeVariable(value) + + def wrap_weakref(self, value: weakref.ReferenceType): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WeakRefVariable.build(self.tx, value, source=self.source) + + def wrap_removable_handle(self, value): + # This means that the removable handle was created in some other frame. + # Our current infra requires the hook to be registered and removed in + # the same frame. So graph break. + # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks + unimplemented( + gb_type="Attempted to represent unregistered RemovableHandle", + context="", + explanation="Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, " + "which is not supported. This happens because the RemovableHandle was created in another frame.", + hints=[], + ) + + def wrap_jit_function(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + + def wrap_mapping_proxy(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + # This might be suboptimal compared to dict guards. But mappingproxy is + # not very common, so its ok to guard on all keys. + self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK) + all_const = all(ConstantVariable.is_literal(k) for k in value) + + if not all_const: + unimplemented( + gb_type="non-const keys in mappingproxy", + context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", # noqa: SIM118 + explanation="Dynamo expects mappingproxy keys to be constants.", + hints=[ + "Ensure your mappingproxy keys are constants (e.g. int, float, strings)", + ], + ) + + def build_key_value(k, v): + key = ConstantVariable.create(k) + source_key = k + + source_value = GetItemSource(self.get_source(), source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + items = dict(build_key_value(k, v) for k, v in value.items()) + + # Create a dict_vt to be used in the mapping proxy variable + dict_vt = ConstDictVariable(items, source=None) + result = MappingProxyVariable(dict_vt, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + @classmethod + @functools.cache + def _id_dispatch( + cls, + ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: + from ..comptime import comptime + + entries = [ + (comptime, lambda self, value: ComptimeVariable()), + ( + dataclasses.fields, + lambda self, value: LambdaVariable( + _dataclasses_fields_lambda, + source=self.source, + **self.install_guards(GuardBuilder.CLOSURE_MATCH), + ), + ), + (torch.__version__, lambda self, value: TorchVersionVariable()), + ] + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, (tuple, list)) else (ts,): + assert t not in result + result[id(t)] = fn + + return result + + def _wrap(self, value): + # import here to avoid circular dependencies + from torch.utils._triton import ( + has_triton, + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, + ) + + from ..decorators import ( + DynamoConfigPatchProxy, + ErrorOnGraphBreakDecoratorContextManager, + ) + + if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + else: + + class JITFunction: + pass + + class Autotuner: + pass + + # default implementations, in case we don't have triton (or the wrong triton version) + def create_1d_tma_descriptor(): + pass + + def create_2d_tma_descriptor(): + pass + + class TensorDescriptor: + @staticmethod + def from_tensor(): + pass + + if has_triton_experimental_host_tma(): + from triton.tools.experimental_descriptor import ( # noqa: F811 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 + + # Handle exact type() match + type_dispatch = self._type_dispatch().get(type(value)) + if type_dispatch is not None: + return type_dispatch(self, value) + + # Handle exact id() match + id_dispatch = self._id_dispatch().get(id(value)) + if id_dispatch is not None: + return id_dispatch(self, value) + + # Everything else (NB: order matters!) + if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses + ): + if ( + type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + or is_traceable_wrapper_subclass(value) + ): + return self.wrap_tensor(value) + + if is_namedtuple(value): + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + output = [ + LazyVariableTracker.create( + getattr(value, name), + source=AttrSource(self.source, name), + ) + for name in namedtuple_fields(type(value)) + ] + result = NamedTupleVariable( + output, tuple_cls=type(value), source=self.source + ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): + self.install_guards(GuardBuilder.TYPE_MATCH) + all_const = all(ConstantVariable.is_literal(k) for k in value) + + # For all_const, we don't have to guard on anything yet. We guard on + # keys lazily by adding a dict_getitem entry for each accessed key. + # For cases where we need to guard on all keys, we lazily put guards + # during the dict call_method (check dicts.py) + if not all_const: + # Guard on the key order + # This is not ideal, i.e., there is no need to guard on the key + # order. But we guard on the key order because of the complexity + # + # 1) For non-constant objects, we can't save the key in the + # guard context because it can be memory heavy. We can add + # weakrefs but this complicates the accesses. + # + # 2) For non-constant objects, we also have to guard on the keys + # (like TENSOR_MATCH on tensor). We might also have guards on + # the attributes of the keys (like tensor.grad). To make this + # work in tree structure is complicated. + # + # So, instead we guard on the key order. While guarding on key + # order, we just save the indices and use it to access keys and + # values. Indices are cheap to save. + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + if all_const: + key = ConstantVariable.create(k) + source_key = k + else: + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + source_value = DictGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. + result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + if istype(value, collections.defaultdict): + factory_source = AttrSource(self.source, "default_factory") + result = DefaultDictVariable( + result, + type(value), + default_factory=VariableBuilder(self.tx, factory_source)( + value.default_factory + ), + source=self.source, + ) + else: + result = ConstDictVariable( + result, user_cls=type(value), source=self.source + ) + + return self.tx.output.side_effects.track_mutable(value, result) + elif isinstance(value, torch.nn.Module): + return self.wrap_module(value) + elif ConstantVariable.is_literal(value): # non-atomic literals + return self.wrap_literal(value) + elif isinstance(value, torch.overrides.TorchFunctionMode): + var = TorchFunctionModeVariable(value, source=self.source) + self.tx.output.side_effects.track_object_existing(value, var) + return var + elif istype(value, set): + if any(isinstance(x, torch.Tensor) for x in value): + unimplemented( + gb_type="Attempted to wrap a set with tensors", + context="Python set containing torch.Tensor elements", + explanation=( + "Dynamo cannot trace sets of tensors. To get a stable ordering, " + "Dynamo needs to convert the set into a list and the order might not be " + "stable if the set contains tensors." + ), + hints=[ + "Use a dictionary where the keys are tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # The list gives a ordering for the set items. The ordering is based + # on the Python hash and it is not related to object ordering inside + # the set object. The order being incorrect at runtime will lead to + # a recompilation. + L = list(value) + items = [ + LazyVariableTracker.create( + v, source=NonSerializableSetGetItemSource(self.source, i) + ) + for i, v in enumerate(L) + ] + result = SetVariable(items, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif istype(value, frozenset) and all( + ( + # For DBR quantization, we could get a frozenset of torch funcs. + (type(x) is types.BuiltinMethodType and x.__module__ == "torch") + or + # Another commonly used frozenset of types. + x in torch.utils._pytree.BUILTIN_TYPES + ) + for x in value + ): + # For the limited cases of frozenset here, we know the items won't + # change across runs, so we can safely create sourceless VTs for + # them and only guard on the frozenset id. + # TODO support source for sets and remove the special logics here. + items = [SourcelessBuilder.create(self.tx, v) for v in value] + self.install_guards(GuardBuilder.ID_MATCH) + return FrozensetVariable(items, source=self.source) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + self.install_guards(GuardBuilder.ID_MATCH) + return EnumVariable(value=value, source=self.source) + elif DebuggingVariable.is_reorderable_logging_function(value): + # Put this above builtin_callable so that print() can be handled + # along with other builtin debugging functions + self.install_guards(GuardBuilder.BUILTIN_MATCH) + return DebuggingVariable(value, source=self.source) + elif isinstance(value, logging.Logger): + self.install_guards(GuardBuilder.TYPE_MATCH) + return LoggingLoggerVariable(value, source=self.source) + elif is_utils_checkpoint(value): + return build_checkpoint_variable(source=self.source) + elif is_invoke_subgraph(value): + return build_invoke_subgraph_variable(source=self.source) + elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value): + return LocalMapWrappedHigherOrderVariable.build(source=self.source) + elif isinstance(value, functools.partial): + func_src = AttrSource(self.get_source(), "func") + func_obj = VariableBuilder(self.tx, func_src)(value.func) + + args = [] + args_source = AttrSource(self.get_source(), "args") + for i, arg in enumerate(value.args): + args.append( + VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) + ) + + keywords = {} + keywords_source = AttrSource(self.get_source(), "keywords") + for k, v in value.keywords.items(): + if not ConstantVariable.is_literal(k): + unimplemented( + gb_type="functools.partial() with non-literal keyword", + context=f"non-literal keyword: {k}", + explanation="functools.partial() expects literal/string keywords", + hints=[*graph_break_hints.USER_ERROR], + ) + keywords[k] = VariableBuilder( + self.tx, DictGetItemSource(keywords_source, k) + )(v) + + install_guard( + self.get_source().make_guard(GuardBuilder.TYPE_MATCH), + keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), + args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + ) + return FunctoolsPartialVariable(func_obj, args, keywords) + elif is_typing(value): + # typing.List, typing.Mapping, etc. + self.install_guards(GuardBuilder.ID_MATCH) + return TypingVariable( + value, + source=self.source, + ) + elif np is not None and isinstance(value, np.generic): + # numpy array scalars: convert to 0D arrays + return self.wrap_numpy_ndarray(np.asarray(value)) + elif trace_rules.is_numpy(value): + assert np + if istype(value, types.MethodType): + # Dont guard on cython functions as they dont change ids + if inspect.isfunction(value.__func__): + install_guard( + AttrSource(self.source, "__func__").make_guard( + GuardBuilder.CLOSURE_MATCH + ) + ) + elif inspect.isclass(value): + self.install_guards(GuardBuilder.CLASS_MATCH) + elif inspect.isfunction(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + elif callable(value): + self.install_guards(GuardBuilder.ID_MATCH) + else: + self.install_guards(GuardBuilder.TYPE_MATCH) + return NumpyVariable(value, source=self.source) + elif trace_rules.is_numpy_dtype(value): + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyDTypeVariable(value, source=self.source) + elif trace_rules.is_numpy_type_info(value): + if isinstance(value, np.iinfo): + self.install_guards(GuardBuilder.TYPE_MATCH) + dt_source = AttrSource(self.source, "dtype") + install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantLikeVariable(value, source=self.source) + # NB: These can't be put in type_dispatch, they have to run later + elif CollectiveFunctionRewriteVariable.can_rewrite(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return CollectiveFunctionRewriteVariable.create( + self.tx, + value, + source=self.source, + ) + elif istype(value, torch.autograd.function.FunctionMeta): + self.install_guards(GuardBuilder.CLASS_MATCH) + return AutogradFunctionVariable( + value, + source=self.source, + ) + elif isinstance(value, torch.autograd.function.FunctionCtx): + actual_saved_tensors = None + try: + actual_saved_tensors = value.saved_tensors + except RuntimeError: + pass + + saved_tensors = [] + guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)] + if isinstance(actual_saved_tensors, tuple): + saved_tensors_source = AttrSource(self.source, "saved_tensors") + guards.append( + saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + for i, v in enumerate(actual_saved_tensors): + saved_tensors.append( + VariableBuilder( + self.tx, GetItemSource(saved_tensors_source, i) + )(v) + ) + install_guard(*guards) + + return self.tx.output.side_effects.track_object_existing( + value, + AutogradFunctionContextVariable( + value, + source=self.source, + saved_tensors=SavedTensorBox(saved_tensors), + ), + ) + elif ( + isinstance(value, types.MethodType) + and istype( + getattr(value, "__self__", None), torch.autograd.function.FunctionMeta + ) + and getattr(value, "__name__", "") == "apply" + and value == getattr(value.__self__, "apply", None) + ): + # handle aliased autograd function `apply` calls + install_guard( + AttrSource(self.get_source(), "__func__").make_guard( + GuardBuilder.CLOSURE_MATCH + ) + ) + return GetAttrVariable( + AutogradFunctionVariable( + value.__self__, source=AttrSource(self.source, member="__self__") + ), + "apply", + ) + elif isinstance(value, torch._C._ImperativeEngine): + self.install_guards(GuardBuilder.ID_MATCH) + return AutogradEngineVariable(value, source=self.source) + elif ( + value + is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub + ): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return LambdaVariable( + lambda: UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, + ).call_function( + self.tx, + (self.tx.output.side_effects.get_ca_final_callbacks_var(),), + {}, + ) + ) + elif isinstance(value, DynamoConfigPatchProxy): + return DynamoConfigPatchVariable(value.changes) + elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager): + return ErrorOnGraphBreakVariable(value.error_on_graph_break) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + self.tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value).create_with_source( + value, source=self.source + ) + elif np and isinstance(value, np.number): + return self.wrap_unspecialized_primitive(value) + elif isinstance(value, HigherOrderOperator): + if value is torch._higher_order_ops.invoke_subgraph: + unimplemented( + gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", + context="", + explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", + hints=[], + ) + self.install_guards(GuardBuilder.TYPE_MATCH) + return TorchHigherOrderOperatorVariable.make(value, source=self.source) + elif isinstance(value, torch.cuda.StreamContext): + self.install_guards(GuardBuilder.ID_MATCH) + stream_source = AttrSource(self.source, "stream") + stream_var = VariableBuilder(self.tx, stream_source)(value.stream) + return StreamContextVariable.create(self.tx, stream_var) + elif isinstance(value, torch.Stream): + # This refers to the device-agnostic torch.Stream + self.install_guards(GuardBuilder.TYPE_MATCH) + index = register_user_object(value, self.source) + stream_proxy = self.tx.output.create_proxy( + "call_function", get_external_object_by_index, (index,), {} + ) + set_example_value(stream_proxy.node, value) + var = StreamVariable( + stream_proxy, value, source=self.source, user_object_index=index + ) + return self.tx.output.side_effects.track_object_existing(value, var) + elif isinstance(value, (torch._C._SDPAParams)): + self.install_guards(GuardBuilder.TYPE_MATCH) + return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter): + self.install_guards(GuardBuilder.ID_MATCH) + return FuncTorchInterpreterVariable(value) + elif isinstance(value, torch.Event): + self.install_guards(GuardBuilder.TYPE_MATCH) + index = register_user_object(value, self.source) + event_proxy = self.tx.output.create_proxy( + "call_function", + get_external_object_by_index, + (index,), + {}, + ) + set_example_value(event_proxy.node, value) + return EventVariable( + event_proxy, + value, + index, + source=self.source, + ) + elif ( + istype(value, contextlib.nullcontext) + and inspect.getattr_static(value, "enter_result", None) is None + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return NullContextVariable(source=self.source) + elif KeyedJaggedTensorVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = KeyedJaggedTensorVariable(value, source=self.source) + # TODO: this doing it manually is bad + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, torch.optim.Optimizer): + self.install_guards(GuardBuilder.ID_MATCH) + self.source = OptimizerSource(self.source) + return OptimizerVariable(value, source=self.source) + elif isinstance(value, torch.DispatchKeySet): + self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH) + return DispatchKeySetVariable(value) + elif WorldMetaClassVariable.is_group_member_type(value): + return WorldMetaClassVariable(value, source=self.source) + elif ProcessGroupVariable.is_process_group(value): + self.install_guards(GuardBuilder.ID_MATCH) + return ProcessGroupVariable(value, source=self.source) + elif DeviceMeshVariable.is_device_mesh(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return DeviceMeshVariable(value, source=self.source) + elif PlacementClassVariable.is_placement_type(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return PlacementClassVariable(value, source=self.source) + elif PlacementVariable.is_placement(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return PlacementVariable( + value, + source=self.source, + ) + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): + self.install_guards(GuardBuilder.CLASS_MATCH) + return ItertoolsVariable(value, source=self.source) + elif isinstance(value, _DynamicScalar): + is_int = isinstance(value, DynamicInt) + source = DynamicScalarSource(self.source, is_int) + if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes: + # If we've already seen this dynamic scalar, reuse the existing + # SymInt/SymFloat node. + node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)] + else: + sym = self.tx.output.shape_env.create_unspecified_symbol( + value.real, + source=source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + node = self.tx.output.shape_env.create_symintnode( + sym, + hint=value.real, + source=source, + ) + + # Bind to graph input + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(node), + node, + source=source, + ) + sym_node_proxy.node.meta["grapharg"] = GraphArg( + source, + node, + False, + None, + is_tensor=False, + example_strong_ref=node, + ) + sym_expr = node.node.expr + assert isinstance(sym_expr, sympy.Symbol), ( + f"{sym_expr} is not a basic Symbol." + ) + self.tx.output.tracked_fakes.append(TrackedFake(node, source, None)) + return SymNodeVariable.create(self.tx, sym_node_proxy, node) + elif is_torch_sym(value): + # Note: this doesn't handle nested symints. + # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo. + + # Concretely, + # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). + # so that guards on the SymInts can be effectively applied on the original SymBool in user program. + # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program + # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly. + source = ( + self.source + if isinstance(value, torch.SymInt) + else ConvertIntSource(self.source) + ) + if value.node.has_hint(): + new_symint = ( + self.tx.output.shape_env.create_unspecified_symint_and_symbol( + int(value.node.hint), + source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + ) + else: + if isinstance(value, torch.SymBool): + # We need to create an unbacked symint to replace the unbacked symbool. + new_symint = self.tx.output.shape_env.create_unbacked_symint() + else: + # TODO (yidi): we need to figure out a way to propagate the guards + # we accumulated when tracing the subggraph to outer shape_env. For normal symints, + # this is automatically done by evaluating the guards once but this + # will cause data-dependent error when we evaluate the outer unbacked symints. + # The test case that triggers this graph break is test_cond_unbacked_symint_closure + unimplemented( + gb_type="Attempted to wrap unbacked SymInt", + context="", + explanation="Unbacked SymInt input is not supported yet.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(new_symint), + new_symint, + source=source, + ) + + sym_node_proxy.node.meta["grapharg"] = GraphArg( + source, + new_symint, + False, + None, + is_tensor=False, + example_strong_ref=new_symint, + ) + # We bind the new_symint to graph input. + sym_expr = new_symint.node.expr + assert isinstance(sym_expr, sympy.Symbol), ( + f"{sym_expr} is not a basic Symbol." + ) + self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None)) + + tracing_symint = ( + new_symint if isinstance(value, torch.SymInt) else new_symint == 1 + ) # cast it back to symbool for tracing + return SymNodeVariable(sym_node_proxy, tracing_symint) + + elif isinstance(value, (JITFunction, Autotuner)): + self.install_guards(GuardBuilder.ID_MATCH) + return TritonKernelVariable( + value, + None, # No kernel idx provided + None, # No grid provided + source=self.source, + ) + elif value is create_1d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=1) + elif value is create_2d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=2) + elif value is TensorDescriptor.from_tensor: + return CreateTMADescriptorStableVariable() + elif isinstance(value, torch.amp.autocast_mode.autocast): + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) + elif TorchCtxManagerClassVariable.is_matching_cls(value): + if inspect.isclass(value): + self.install_guards(GuardBuilder.CLASS_MATCH) + elif inspect.isfunction(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return TorchCtxManagerClassVariable(value, source=self.source) + elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "__original_fn", source=self.source + ) + elif is_lru_cache_wrapped_function(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif value is traceback.clear_frames: + return TracebackVariable(source=self.source) + elif value is sys.exc_info or ( + sys.version_info >= (3, 11) and value is sys.exception + ): + return SysFunctionVariable(value, source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + elif value is functools.wraps: + self.install_guards(GuardBuilder.ID_MATCH) + return FunctoolsWrapsVariable(value, source=self.source) + elif value is collections.namedtuple: + self.install_guards(GuardBuilder.ID_MATCH) + return CollectionsNamedTupleFunction(value, source=self.source) + elif isinstance( + value, types.BuiltinMethodType + ) and BuiltinMethodVariable.is_supported_builtin_method(value): + self.install_guards(GuardBuilder.ID_MATCH) + return BuiltinMethodVariable(value, source=self.source) + elif is_function(value) and value in (float.fromhex, float.hex): + self.install_guards(GuardBuilder.ID_MATCH) + return GetAttrVariable( + BuiltinVariable(float, source=self.source), + value.__name__, + ) + elif is_function_or_wrapper(value): + value, attr_name = unwrap_with_attr_name_if_wrapper(value) + # For these wrappers, Dynamo points to the wrapped function, + # so source needs to be updated as well. + if attr_name is not None: + self.source = AttrSource(self.source, attr_name) + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + elif value is random.Random: + self.install_guards(GuardBuilder.ID_MATCH) + return RandomClassVariable(source=self.source) + elif istype(value, random.Random) and RandomVariable.is_supported_random_obj( + value + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = RandomVariable(value, source=self.source) + self.tx.output.side_effects.track_mutable(value, result) + return result + # Don't use istype, since some python modules are not subclasses of types.ModuleType directly. + # E.g, type(torch.ops) -> , + # type(torch.backends.cudnn) -> + elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): + self.install_guards(GuardBuilder.MODULE_MATCH) + result = PythonModuleVariable( + value, + source=self.source, + ) + self.tx.output.side_effects.track_object_existing(value, result) + return result + elif isinstance(value, types.MethodType) and isinstance( + value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) + ): + # don't let MethodTypes fall through to UserDefinedObject, + # which doesn't support 'CALL_FUNCTION' + + # TODO(whc): Why do we limit this to methods on NNModules? + # I don't have a good reason for this, but it preserves the existing behavior + # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. + # I suspect we probably want to relax this check and dig deeper there. + + # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, + # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here + # and then `__func__` gets wrapped inside UserMethodVariable. + self_obj = VariableBuilder( + self.tx, source=AttrSource(self.source, "__self__") + )(value.__self__) + assert self_obj and isinstance(self_obj, VariableTracker), ( + "Failed to produce a valid self obj" + ) + return UserMethodVariable( + value.__func__, + self_obj, + source=self.source, + ) + elif isinstance(value, types.GetSetDescriptorType): + # GetSet descriptors are C functions attached to an attribute lookup + # using PyGetSetDef. Python, on attribute lookup, can decide to + # create a new object on the fly, and therefore the `id` of the + # descriptors is not guaranteed to be same for different attribute + # accesses. Since these are unlikely to change during the program + # execution, we can skip guarding on them. + return GetSetDescriptorVariable(value) + elif isinstance(value, types.MethodWrapperType): + # Method-wrappers are written in C, and they are not guaranteed to + # return the same object on attribute lookup. Therefore, we cannot + # insert a ID_MATCH guard here. method-wrappers are very + # unlikely to change, so its ok to skip the guard here. + return MethodWrapperVariable(value) + elif issubclass(type(value), type) and issubclass(value, BaseException): + # match user defined exceptions + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedExceptionClassVariable(value) + elif issubclass(type(value), type): + if value in ( + torch.utils.hooks.BackwardHook, + torch.nn.Parameter, + torch.nn.Buffer, + ): + # TODO(jansel): combine this case with the one above + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + if value is torch.autograd._unsafe_preserve_version_counter: + self.install_guards(GuardBuilder.CLASS_MATCH) + return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). + and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) + + if not is_from_closure_source(self.source): + # For closure source, the variable comes from LOAD_SUPER_ATTR, + # which calls self.__class__. This is internal Cpython + # implementation, and it is rare for the user to modify + # self.__class__ manually. + # For other cases, this is a userdefined class, so install an + # ID_MATCH even if its a global variable. + self.install_guards(GuardBuilder.CLASS_MATCH) + + if is_opaque_type(value): + return OpaqueObjectClassVariable( + value, + source=self.source, + ) + + return UserDefinedClassVariable( + value, + source=self.source, + ) + elif TorchScriptObjectVariable.is_matching_cls(type(value)): + from ..source import ( + FlattenScriptObjectSource, + ScriptObjectQualifiedNameSource, + ) + + if torch._library.fake_class_registry.tracing_with_real(value): + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, value + ) + return TorchScriptObjectVariable.create( + proxy, + value, + source=self.source, + ) + + if is_opaque_type(type(value)): + # Check if this is a value-type opaque object (registered as both opaque type and constant) + if is_opaque_value_type(type(value)): + # Value-type: guard on equality (will use __eq__) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TorchScriptObjectVariable.create( + value, + value, + source=self.source, + ) + else: + # Reference-type: guard only on type/identity + self.install_guards(GuardBuilder.TYPE_MATCH) + + elif not hasattr(value, "__obj_flatten__"): + # This exists to allow a smoother transition. + # The implications are: + # The script objects won't be tracked as proxies. + # Methods on these objects won't show up in the graph. + # The original script object might be mutated. + return self.wrap_user_defined(value) + else: + # Install the guards on the fully qualified name of the script object + LazyVariableTracker.realize_all( + VariableBuilder( + self.tx, ScriptObjectQualifiedNameSource(self.source) + )( + value._type().qualified_name() # type: ignore[attr-defined] + ) + ) + # Install the guards on the content of the script object by setting the source + # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents. + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))( + value.__obj_flatten__() + ) + ) + + fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( + self.tx.output.fake_mode, value + ) + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + fake_script_obj, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, fake_script_obj + ) + return TorchScriptObjectVariable.create( + proxy, + fake_script_obj, + source=self.source, + ) + elif ( + isinstance(value, (dict, collections.OrderedDict)) + and type(value).__new__ is dict.__new__ + ): + # Construct a dict_vt that will reside inside the UserDefinedDictVariable + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Guard on the key order + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + + source_value = DictSubclassGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. + result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + dict_vt = ConstDictVariable( + result, + user_cls=( + collections.OrderedDict + if isinstance(value, collections.OrderedDict) + else dict + ), + mutation_type=ValueMutationExisting(), + source=self.source, + ) + # Force this to reconstruct on mutation to keep the reconstruction + # bytecode simple + dict_vt.should_reconstruct_all = True + + result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, tuple): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying tuple data structure. + output = [ + LazyVariableTracker.create( + tuple.__getitem__(value, i), + source=GetItemSource(self.get_source(), i), + ) + for i in range(tuple.__len__(value)) + ] + + tuple_vt = TupleVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedTupleVariable( + value, tuple_vt=tuple_vt, source=self.source + ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, list): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying list data structure. + output = [ + LazyVariableTracker.create( + list.__getitem__(value, i), + source=ListGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(value)) + ] + list_vt = ListVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, (set, frozenset)): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + L = list(dict.fromkeys(value)) + output = [ + LazyVariableTracker.create( + list.__getitem__(L, i), + source=NonSerializableSetGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(L)) + ] + set_vt_cls = SetVariable if isinstance(value, set) else FrozensetVariable + set_vt = set_vt_cls( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass(type(value), MutableMapping): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = MutableMappingVariable(value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif is_frozen_dataclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, dict_keys): + if all(ConstantVariable.is_literal(k) for k in value): + # If the dict_keys object is passed from outside the compile region, it must either be passed along with + # the corresponding dict object or treated as a set (when only the keys are passed into the compiled region). + # - If it is passed along with the dict, the dict object itself is already guarded. + # - If only the dict_keys object is passed, we add EQUALS_MATCH and SEQUENCE_LENGTH guards + # to ensure it remains unchanged across multiple runs. + items = [SourcelessBuilder.create(self.tx, v) for v in value] + install_guard( + self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH), + self.get_source().make_guard(GuardBuilder.EQUALS_MATCH), + ) + return DictKeySetVariable(items, source=self.source) + else: + unimplemented( + gb_type="non-const keys in dict_keys", + context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", + explanation="Dynamo expects dict_keys keys to be constants.", + hints=[ + "Ensure your dict_keys keys are constants (e.g. int, float, strings)", + ], + ) + elif IntWrapperVariable.is_matching_object(value): + from torch.export.dynamic_shapes import _DimHintType + + if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC: + return self.wrap_symint(value.val) + elif value.dynamism.type == _DimHintType.DYNAMIC: + log.debug( + "%s marked %s via IntWrapper", + self.source.name, + DimDynamic.DYNAMIC, + ) + return self.wrap_symint( + value.val, + dynamism=DimDynamic.DYNAMIC, + context=SymIntSymbolicContext( + constraint=RelaxedUnspecConstraint(warn_only=False) + ), + ) + elif value.dynamism.type == _DimHintType.AUTO: + log.debug( + "%s marked %s via IntWrapper", + self.source.name, + DimDynamic.DYNAMIC, + ) + return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC) + else: + raise RuntimeError(f"Undefined dynamism {value.dynamism}") + else: + return self.wrap_user_defined(value) + + def wrap_user_defined(self, value: Any): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = UserDefinedObjectVariable(value, source=self.source) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + + def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): + for item in value: + if item is value: + unimplemented( + gb_type="list elements are pointing to the list itself", + context="", + explanation="Dynamo does not support lists whose items reference to itself", + hints=["Avoid using self referential list"], + ) + + if config.specialize_int and type(value) is torch.Size: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + # One can index a tensor with a list/tuple. Therefore, we need to + # have a stricter match. + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Tuples are immutable objects, so we should mark its items static. This + # avoids wrapping of tuple items as symints. This helps for nn module + # attributes like conv2d strides, dilations. + if ( + istype(value, tuple) + and all(ConstantVariable.is_literal(item) for item in value) + and self.source.guard_source.is_unspecialized_nn_module() + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TupleVariable([ConstantVariable.create(item) for item in value]) + + output = [ + LazyVariableTracker.create( + item, + source=GetItemSource(self.get_source(), i), + ) + for i, item in enumerate(value) + ] + + maybe_gm = self.tx.output.local_scope.get("self") + if isinstance( + self.source, LocalSource + ) and self.source.local_name in get_locals_to_steal(maybe_gm): + # The input tensor list to dynamo from compiled autograd may contain activations + # which are freed as they are used in inductor. Dynamo's default behavior is to + # lift all tensors to the graph inputs, but this will cause dynamo to hold an + # extra reference to the activation tensors and increase peak memory usage. + # To allow freeing ASAP, we keep the list as graph argument to the dynamo output + # graph, and unpack it locally. + # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have + # `def forward(self, L_inputs_):` + source = self.source + assert isinstance(value, list) + tensor_list_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=source, + ) + tensor_list_proxy.node.meta["steal_arg"] = True + + list_variable = wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=self.tx, + proxy=tensor_list_proxy, + example_value=value, + subclass_type=None, + source=source, + ) + + # Apply relevant logic from `VariableTracker.build(value[i])` + # (except for the `create_graph_input` stuff). + guards = [] + for i, tensor_variable in enumerate(list_variable.items): + source_i = GetItemSource(base=source, index=i, index_is_slice=False) + # access unpacked tensor from this list instead of from a lifted arg + self.tx.output.input_source_to_var[source_i] = tensor_variable + tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict( + value[i] + ) + guard = functools.partial( + GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i]) + ) + guards.append(source_i.make_guard(guard)) + + install_guard(*guards, skip=1) + + grapharg = GraphArg( + source, + value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + tensor_list_proxy.node.meta["grapharg"] = grapharg + + # The following is very important for maintaining the "python object + # <==> variable tracker" 1-to-1 mapping, which is mainly handled via + # `side_effects`. Note that constructing `tensor_variable` above + # already adds it to graph arg, but we never registered it with + # `side_effects`. The preemptive `realize` calls here basically + # does that registration (at the end of `self.__call__`). + # + # A slightly cleaner alternative is to register the + # `tensor_variable`s above with `side_effects` directly, and just + # return the `list_variable`, but that breaks some tensor-subclass + # related tests like `test_inputs_aliasing_bytecode_stack_restore`, + # because `tensor_variable` is constructed via + # `handle_traced_output`, which doesn't really expect/handle tensor + # subclass. + # + # Eventually, we expect to fix remove all of these by having Dynamo + # auto-boxing inputs to the compiled graph, see + # https://github.com/pytorch/pytorch/issues/153701. + for vt in output: + vt.realize() + + result = BaseListVariable.cls_for_instance(value)(output, source=self.source) + if istype(value, (list, collections.deque)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def wrap_tuple_iterator(self, value: tuple_iterator): + self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN) + output = [ + VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( + tuple_iterator_getitem(value, i) + ) + for i in range(tuple_iterator_len(value)) + ] + result = TupleIteratorVariable(output, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_range_iterator(self, value: range_iterator): + self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH) + # Get all the values from the range iterator; no need to install guards + # on items since `RANGE_ITERATOR_MATCH` guarantees the same items. + items = [ConstantVariable.create(v) for v in copy.deepcopy(value)] + result = ListIteratorVariable(items, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_slice_range(self, value: Union[slice, range]): + items = [ + VariableBuilder(self.tx, AttrSource(self.get_source(), k))( + getattr(value, k) + ) + for k in ("start", "stop", "step") + ] + self.install_guards(GuardBuilder.TYPE_MATCH) + if isinstance(value, slice): + return SliceVariable(items, self.tx, source=self.source) + else: + return RangeVariable(items, source=self.source) + + def mark_static_input(self, value: torch.Tensor, guard: bool): + from ..decorators import mark_static_address + + static_inputs_log.debug( + "Marking static input %s, id: %s)", self.source.name, id(value) + ) + mark_static_address(value, guard=guard) + + # Check if we've seen this tensor before and update graph metadata if needed + # As long as this runs before AOT this is sound + if value in self.tx.output.side_effects: + var = self.tx.output.side_effects[value] + var.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = ( + value._dynamo_static_input_type + ) + + def wrap_module(self, value: torch.nn.Module): + from ..eval_frame import OptimizedModule + + if len(value.__dict__) == 0: + unimplemented( + gb_type="Uninitialized nn.Module", + context=typestr(value), + explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", + hints=[ + *graph_break_hints.USER_ERROR, + "Ensure your nn.Module instance has called `super().__init__()`.", + ], + ) + if istype(value, OptimizedModule): + # Check if the optimized module was disabled + if inspect.getattr_static(value.forward, "_torchdynamo_disable", False): + # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If + # we graph break here, Dynamo does not know how to create + # continuation functions for such bytecodes. So, we delay the + # graph break to CALL_FUNCTION. + msg = inspect.getattr_static( + value.forward, "_torchdynamo_disable_msg", None + ) + return DelayGraphBreakVariable( + source=self.source, + msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.source = AttrSource(self.source, "_orig_mod") + return self.wrap_module(value._orig_mod) + + if ( + isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) + and not config.allow_rnn + ): + unimplemented( + gb_type="Attempted to wrap RNN, GRU, or LSTM", + context=str(value), + explanation="Dynamo does not support RNN, GRU, or LSTM.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if getattr(value, "_is_fsdp_managed_module", False): + # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + # in fully_sharded_data_parallel.py for more information + + # we can't do this assert inside FSDP constructor, + # since we don't know yet whether dynamo will be used + if not getattr(value, "_fsdp_use_orig_params", False): + unimplemented( + gb_type="FSDP with use_orig_params=False", + context="", + explanation="Dynamo only supports FSDP with use_orig_params=True", + hints=[], + ) + + # Note on FSDP guarding + # Eager FSDP already assumes (requires, but without enforcement) + # that users don't mutate their model parameters/structure after + # FSDP wrapping, because FSDP wouldn't notice or update its + # FlatParams. + # + # Therefore, torch.compile can skip guarding on params or submodule + # structure of fsdp_managed modules, by using FSDPNNModuleSource as + # the guard source. This behavior is gated on + # config.skip_fsdp_guards. + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FSDPManagedNNModuleVariable(value, source=self.get_source()) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): + # created dynamically, don't specialize on it + + # Note [Tracing a torch.compiled function] + # when make_fx tracing a compiled function, we need + if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): + value = value.get_base() + self.source = AttrProxySource(self.source) + + if torch._dynamo.config.inline_inbuilt_nn_modules: + freezing = is_parameter_freezing() + + # Guard against the case where user may overwrite named parameters + # / named buffers + # NOTE: This is not likely to happen but worth guarding to avoid + # exception + if ( + callable(value.named_parameters) + and value.named_parameters.__func__ + is og_module_named_parameters_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, p in value.named_parameters(): + self.mark_static_input(p, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if ( + callable(value.named_buffers) + and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, b in value.named_buffers(): + self.mark_static_input(b, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if freezing: + # we need to add the module to tracing context + # in order to allow its params to get invalidated + # this will get cleaned up once compile ends + self.tx.output.nn_modules[self.name] = value + + if ( + value.__module__.startswith(("torch.nn.modules", "torch.ao.")) + and not value.__module__.startswith("torch.nn.modules.container") + ) or getattr(value.__class__, "_dynamo_marked_static", False): + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedBuiltinNNModuleSource(self.source) + result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + else: + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedNNModuleSource(self.source) + result = UnspecializedNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + + self.tx.output.add_fqn_info_for_inlined_modules(value, self.source) + + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass( + value.__class__, torch.nn.parallel.distributed.DistributedDataParallel + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return UnspecializedNNModuleVariable(value, source=self.get_source()) + else: + return self.tx.output.register_attr_or_module( + value, + self.name, + source=self.get_source(), + # Guards are added inside register_attr_or_module + ) + + def wrap_literal(self, value): + if type(value) is int: + # allowlist has higher precedence over specialization control. + if is_dynamic_source(self.source.name): + log.debug("%s marked dynamic via source whitelist", self.source.name) + return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC) + + if is_unbacked_source(self.source.name): + log.debug("%s marked unbacked via source whitelist", self.source.name) + return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED) + + if not config.specialize_int: + # unspecializing int by default, but still + # specialize for the following conditions + if is_int_specialization_case(value, self.source): + recompile_hint = None + if ( + self.source.guard_source.is_unspecialized_builtin_nn_module() + or self.source.guard_source.is_unspecialized_nn_module() + ): + # This means that it is an integer from a NN module. + # Dynamo considers nn module int attributes to be static + # (a good heuristic). But a user might want to mark the + # int attribute to be a symint, so track this integer + # for recompilation later. + recompile_hint = ( + "torch.compile considers integer attributes of the nn.Module to be static. " + "If you are observing recompilation, you might want to make this integer dynamic " + "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this " + "integer into a tensor." + ) + + process_automatic_dynamic( + self.tx, + self.source.name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + self.install_guards( + functools.partial( + GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint + ) + ) + return ConstantVariable.create(value=value, source=self.source) + + return self.wrap_symint(value) + elif not config.specialize_float and type(value) is float: + return self.wrap_symfloat(value) + else: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + result = ConstantVariable.create(value=value, source=self.source) + if isinstance(value, (list, set)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): + if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: + raise InternalTorchDynamoError( + "Cannot wrap a Tensor that has already been", + "wrapped by this instance of Dynamo", + ) + + def wrap_tensor(self, value: torch.Tensor): + source = self.get_source() + + # We cannot already be tracking the tensor, which implies + # it would have already been wrapped + assert value not in self.tx.output.side_effects + + is_static_input = get_static_address_type(value) is not None + + if ( + config.inline_inbuilt_nn_modules + and not is_static_input + and ( + isinstance(value, torch.nn.Parameter) + # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior + # compatible with previous behavior. + or (source and source.guard_source.is_unspecialized_nn_module()) + ) + ): + self.mark_static_input(value, guard=is_parameter_freezing()) + is_static_input = True + + # Install any tensors which are "free" variables; that is: + # 1. Globals + # 2. NonLocals + # 3. tensors that are attributes of nn module + should_install_free_tensor = config.install_free_tensors and ( + is_from_global_source(source) + or is_from_nonlocal_source(source) + or is_from_unspecialized_nn_module_source(source) + ) + + make_graph_attribute = is_static_input and ( + not config.inline_inbuilt_nn_modules + or is_parameter_freezing() + or torch._dynamo.config.prepare_freezing + ) + + if should_install_free_tensor or ( + (source.guard_source.is_specialized_nn_module() or make_graph_attribute) + and not source.guard_source.is_fsdp_module() + ): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if get_static_address_type(value) == "guarded": + # If it's a guarded tensor, we can install the parameter directly + # into the Fx graph instead of lifting it as an input. Lifting + # offers no benefit, such as regional compilation, since we still + # guard on the tensor's ID. Moreover, installing it in the Fx graph + # eliminates the pre-graph bytecode required to extract the tensor + # from locals/globals, reducing overhead. This can lead to + # significant cost savings, especially for optimizers handling many + # tensors. + self.install_guards(GuardBuilder.ID_MATCH) + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if is_constant_source(source): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=source, + # Guards are added inside register_attr_or_module + ) + + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + subclass_type = infer_subclass_type(value) + if subclass_type is not None: + self.install_guards(GuardBuilder.TYPE_MATCH) + + if get_static_address_type(value) == "guarded": + self.install_guards(GuardBuilder.ID_MATCH) + + # By this point, we should have deduplicated all tensors + self.assert_not_wrapped_by_this_graph(value) + + if ( + isinstance(value, torch.Tensor) + and value.is_nested + and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) + ): + unimplemented( + gb_type="Attempted to wrap strided NestedTensor", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards + if ( + isinstance(value, torch.Tensor) + and is_sparse_any(value) + and (not self.tx.export or not config.capture_sparse_compute) + ): + # A hot fix for sparse tensors + torch.compile. Support for + # export + sparsity is being added but we need to create + # SPARSE_TENSOR_GUARDS for guards to work properly. + unimplemented( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented( + gb_type="dtype mismatch between tensor and its gradient", + context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", + explanation="Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # tx.output has multiple tracers if we're introspecting HigherOrderOperator. + # When we've discovered an untracked tensor, then we actually need + # to get Dynamo to track the tensor (which is what this function does) + # and put it as a graph input on the root tracer. Later on, + # if the input is actually used in the body of the HigherOrderOperator, + # then the relevant SubgraphTracer will lift it to being an input of + # the subgraph. + # See NOTE [HigherOrderOperator tracing design] for more details. + + example_value = wrap_to_fake_tensor_and_record( + value, tx=self.tx, is_tensor=True, source=source + ) + + tensor_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, tensor_proxy, value) + + tensor_variable = wrap_fx_proxy( + tx=self.tx, + proxy=tensor_proxy, + example_value=example_value, + subclass_type=subclass_type, + source=source, + **options, + ) + + if value._is_view(): + # If value is a view, add its base tensor to the tracked fakes list. + # This is so we are able to access the correct source for its symbolic + # shape values, in case we need them. + wrap_to_fake_tensor_and_record( + value._base, + tx=self.tx, + source=AttrSource(source, "_base"), + is_tensor=True, + ) + + guard_type = GuardBuilder.TENSOR_MATCH + + if isinstance(source, GradSource) and is_from_optimizer_source(source): + guard_type = GuardBuilder.NOT_NONE_MATCH + + is_dtensor = torch.distributed.is_available() and isinstance( + value, torch.distributed.tensor.DTensor + ) + if not is_dtensor: + # We guard on the _local_tensor and the _spec, and therefore we dont + # have to guard on the outer DTensor. + self.install_guards( + functools.partial( + guard_type, + value=( + value + if isinstance(source, NumpyTensorSource) + else TensorWeakRef(value) + ), + ) + ) + + # We install TYPE_MATCH guards for traceable wrapper subclass object, + # and recursively install corresponding guard for each inner attribute. + if is_traceable_wrapper_subclass(value): + # Tensor subclass guards are very expensive because they are + # implemented in Python. Since DTensor is PyTorch-maintained class, + # we can skip a lot of these guards. + if is_dtensor: + self.install_guards(GuardBuilder.TYPE_MATCH) + + # The inner tensor name is always _local_tensor. If its not, we + # raise assertion to update the check accordingly. + inner_tensor_name = value.__tensor_flatten__()[0][0] + if inner_tensor_name != "_local_tensor": + raise RuntimeError( + "Expecting Dtensor inner tensor name to be _local_tensor" + ) + + # Now selectively guard on the flattening context + flattening_ctx = value.__tensor_flatten__()[1] + # This is supposed to be (self._spec, self.requires_grad) + if not ( + len(flattening_ctx) == 2 + and flattening_ctx[0] == value._spec + and flattening_ctx[1] == value.requires_grad + ): + # If not, raise an assertion to update to the new guards + raise RuntimeError( + "Expecting Dtensor flattening ctx to be _spec, requires_grad" + ) + # Guard on the dtensor spec + install_guard( + AttrSource(self.source, "_spec").make_guard( + GuardBuilder.DTENSOR_SPEC_MATCH + ) + ) + # Move this to C++ + install_guard( + AttrSource(self.source, "requires_grad").make_guard( + GuardBuilder.EQUALS_MATCH + ) + ) + else: + self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) + self.install_guards(GuardBuilder.TYPE_MATCH) + install_guard( + SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) + ) + + attrs, _ = value.__tensor_flatten__() + for attr in attrs: + inner_value = getattr(value, attr) + inner_source = AttrSource(self.source, attr) + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, inner_source)(inner_value) + ) + + self.tx.output.input_source_to_var[source] = tensor_variable + assert "tensor_dict" not in tensor_proxy.node.meta + tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value) + + # Note: this information is conveyed via subclass_type now + fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] + if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: + raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") + + grapharg = GraphArg(source, value, False, fake_tensor_value) + tensor_proxy.node.meta["grapharg"] = grapharg + return tensor_variable + + def wrap_numpy_ndarray(self, value): + assert np is not None + assert isinstance(value, np.ndarray) + + source = NumpyTensorSource(self.get_source()) + + from torch._numpy import _util + + readonly = not value.flags.writeable + if readonly: + try: + value.flags.writeable = True + except ValueError: + # One can not easily make nditer elements writable, + # but warning is not the end of the world + assert isinstance(value.base, np.nditer) + + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented( + gb_type="failed to convert numpy.ndarray to Tensor", + context=str(value), + explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor", + hints=[], + from_exc=e, + ) + + # We do this because we want the full behavior of guarding the numpy ndarray as if it were + # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here + # that there's not another great way to do this atm. + # This creates the right graphargs, as well as registration for guards in tensor names and shape env. + LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value)) + example_value = wrap_to_fake_tensor_and_record( + tensor_value, + tx=self.tx, + is_tensor=False, + source=source, + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(tensor_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, tensor_value) + options = {"source": source} + numpy_ndarray_variable = wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + self.tx.output.input_source_to_var[source] = numpy_ndarray_variable + example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] + + # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be + # converted to a tensor. + grapharg = GraphArg( + source, + tensor_value, + pass_arg_as_tensor=True, + fake_tensor=example_value, + is_tensor=True, + example_strong_ref=tensor_value, + ) + proxy.node.meta["grapharg"] = grapharg + + # TODO - Why do we need to set the source of the np ndarray vt back to + # original source. Many tests fails. + numpy_ndarray_variable.source = self.source + + return numpy_ndarray_variable + + def wrap_symint( + self, + value, + dynamism: Optional[DimDynamic] = None, + context: Optional[SymIntSymbolicContext] = None, + ): + assert type(value) is int + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + shape_env = self.tx.output.shape_env + if TracingContext.get().force_unspec_int_unbacked_size_like: + wrapped_value = shape_env.create_unbacked_symint() + _constrain_range_for_size(wrapped_value) + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, None) + ) + + # NB: We do not do float. For motivation, see + # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit + # but the general idea is that we generate kernels that can + # take unspecialized floats and use them in sizevar computation + elif not is_constant_source(self.get_source()): + if dynamism is None and torch._dynamo.config.specialize_int: + # If specialize_int is False, also return + # a constant (but this should have been handled + # in the caller, TBH). But if `dynamism` is set, then actually + # turn it into a symint + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + name = self.source.name + + frame_state_entry = process_automatic_dynamic( + self.tx, + name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + + # TODO: This should be dynamic, as we in general do not + # know if bare integers are actually going to be sizevars + # and it is inappropriate to eagerly duck size them with + # real sizevars + normalized_source_name = normalize_source_name(self.source.name) + base_source = self.source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if dynamism is not None: + dynamic_dim = dynamism + elif ( + config.automatic_dynamic_shapes + and frame_state_entry.scalar is auto_dynamic + ): + set_feature_use("dynamo.automatic_dynamic_shapes", True) + dynamic_dim = get_automatic_dynamic_shapes_mark_as() + elif ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {0: False})[ + 0 + ] + ) or not config.assume_static_by_default: + dynamic_dim = DimDynamic.DYNAMIC + else: # assume_static_by_default + # TODO: dynamic_dim = DimDynamic.STATIC should work but + # for some reason it doesn't + if frame_state_entry.scalar is auto_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + wrapped_value = shape_env.create_unspecified_symint_and_symbol( + value, + source=self.source, + dynamic_dim=dynamic_dim, + ) + + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, context) + ) + else: + assert is_constant_source(self.get_source()) + # TODO: Do I actually need guard for constant source? + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + wrapped_value, + source=self.get_source(), + ) + + sym_expr = wrapped_value.node.expr + assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol." + self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy + unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if not is_constant_source(self.get_source()): + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + return unspec_var + + def wrap_symfloat(self, value): + # SymFloat wrapping is special. We first wrap it in the same way we + # do an unspecialized primitive, and then we item() it into a + # SymFloat. Removal of the item() call is left to a later FX pass, + # mostly because that pass is more easily done after we have lowered + # to ATen ops. (Dynamo doesn't do decomposition right now). + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + frame_state_entry = process_automatic_dynamic( + self.tx, + self.source.name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + + # NB: we specialize on nan input, because our guard modeling in + # ShapeEnv cannot deal with nan + if ( + torch._dynamo.config.specialize_float + or is_constant_source(self.get_source()) + or math.isnan(value) + or math.isinf(value) + # We don't support cudagraphs for now. Without this cudagraphs + # break because they expect all cuda inputs but our tensorified + # float will be a f64[] cpu tensor. Fixes the following test + # when specialize_float=False + # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950 + or torch._inductor.config.triton.cudagraphs + or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False) + or ( + config.assume_static_by_default + and frame_state_entry.scalar is not auto_dynamic + ) + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # NB: At the point we've gotten here, we don't assume static by + # default. Since we have a guard mechanism, there isn't really any + # downside to trying to be dynamic for float all the time. Unlike + # ints, this won't make codegen perf worse. Modest cost to compile + # time. + + wrapped_value = torch.tensor(value, dtype=torch.float64) + + # We don't support specializing floats for grad checking tensors + # See https://github.com/pytorch/pytorch/pull/140828 for more + # context. + if torch._C._functorch.is_gradtrackingtensor(wrapped_value): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # TODO: Switch RandomValueSource over to use this, this is more + # accurate + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + # The FloatTensorSource here is just for pedantic correctness: if you + # guard against an UnspecializedPythonVariable, you need to guard + # against the tensor-ified version of the local, otherwise it's not a + # Tensor. However, we never let the UnspecializedPythonVariable escape + # here, so there should never actually be any guards against this + # source. + source = FloatTensorSource(self.get_source()) + options = {"source": source, "raw_value": value} + + # TODO: Maybe the tensor-ification should be built into the source, + # rather than by special pattern match + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=source + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + assert isinstance(unspec_var, UnspecializedPythonVariable) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + # There's something a bit incoherent about pass_arg_as_tensor, + # specifically regarding sources. + # + # Specifically, suppose we have "x: float" local argument. We + # eventually end up with an UnspecializedPythonVariable denoting + # torch.as_tensor(x)... but it's source is still L['x'] (which if you + # accessed it directly is a float!) So you gotta be careful when + # setting up your guards, because it's still going to be a float at + # this point, the conversion happens only precisely at the point we're + # actually calling the FX graph. This happens to be what we want for + # shape guard generation, but it's kind of unintuitive. + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + # Directly do item to bypass capture_scalar_outputs + r = wrap_fx_proxy( + self.tx, + self.tx.output.create_proxy( + "call_method", + "item", + *proxy_args_kwargs([unspec_var], {}), + ), + ) + self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None)) + + get_metrics_context().set("tensorify_float_attempt", True, overwrite=True) + + return r + + def wrap_unspecialized_primitive(self, value): + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + wrapped_value = torch.tensor(value) + if not isinstance(self.get_source(), RandomValueSource): + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + options.update({"raw_value": value}) + + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source() + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=self.get_source(), + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + self.tx.output.unspec_variable_map[self.name] = unspec_var + if not is_constant_source(self.get_source()): + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + if unspec_var.is_python_constant(): + # TODO: when can this happen? + example_value = unspec_var.as_python_constant() + else: + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + return unspec_var + + +def _dataclasses_fields_lambda(obj): + if isinstance(obj, UserDefinedObjectVariable): + value = obj.value + else: + unimplemented( + gb_type="dataclass fields failure", + context=f"obj: {obj}; variable type: {type(obj)}", + explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.", + hints=[], + ) + items = [] + for field in dataclasses.fields(value): + source = None + if obj.source: + base_src = AttrSource(obj.source, "__dataclass_fields__") + source = DictGetItemSource(base_src, field.name) + items.append(UserDefinedObjectVariable(field, source=source)) + return TupleVariable(items) + + +def _clone_input(value, fake_mode): + if isinstance(value, torch.Tensor): + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not ( + isinstance(value, FakeTensor) + or ( + # Is functional tensor fakeified by this instance of Dynamo + torch._is_functional_tensor(value) + and maybe_get_fake_mode(value) is fake_mode + ) + or value.is_nested + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + +def wrap_fx_proxy( + tx, proxy, example_value=None, subclass_type=None, **options +) -> VariableTracker: + kwargs = { + "tx": tx, + "proxy": proxy, + "example_value": example_value, + "subclass_type": subclass_type, + **options, + } + if subclass_type is None: + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + else: + result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) + result.install_global(tx) + return result + + +def cache_real_value_when_export(tx, proxy, example_value): + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value, tx.fake_mode + ) + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +# +# This is a horribly complicated function that does too many things, to +# explain what it does, let's first talk about the classic usage wrap_fx_proxy +# for a TensorVariable. There are two primary modes of use: +# +# 1. Wrapping a pre-existing Tensor. In this case, example_value is set +# to the pre-existing Tensor. (Note that this example_value will NOT +# be the final example_value we put into node.meta['example_value'], +# instead it is converted into a fake tensor using +# wrap_to_fake_tensor_and_record and registered as a graph input.) +# +# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In +# this case, example_value is None (and we are going to figure it out +# ourselves using FakeTensors, via get_fake_value, which will run +# the operation represented by the (singular!) FX node referenced by +# the passed in proxy.) +# +# The expectation is you end up with a Tensor output, and everything is +# straightforwardly traced into the graph. +# +# In all cases, the returned `TensorVariable` subclass will have an `example_value` +# and that `example_value` must be a `FakeTensor` produced by the currently running +# instance of Dynamo. +# +# Upon closer inspection, you may notice that there are a slurry of non-Tensor +# output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the +# graph that don't involve tensors. +# +# * Some operators return tuples; we need to recursively handle their +# contents +# +# * Some operators have side effects that will affect subsequent AOTAutograd +# tracing but don't otherwise return anything. +# +# * Some operators return symbolic ints/floats/bools which can go in the +# graph and be traced (but only if they're actually symbolic! If they're +# static you don't want to put them in the graph, which means you +# shouldn't call this function.) +# +# The common theme is that you only use this function WHEN YOU ARE TRACING +# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call +# this function without a proxy. +def wrap_fx_proxy_cls( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + if example_value is None: + out = _wrap_fx_proxy( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + elif isinstance(example_value, torch.Tensor): + out = _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + else: + # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported + # data structures. In essence this just handles tracing some other value which may + # contain Fake Tensors or is otherwise proxyable. + out = handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + if ( + isinstance( + out, + ( + torch._dynamo.variables.TensorVariable, + torch._dynamo.variables.SymNodeVariable, + ), + ) + and proxy.node.op != "placeholder" + ): + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out + + +# This is 1 above (wrapping a preexisting tensor) +def _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, tensor, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tensor, torch.Tensor), ( + f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}" + ) + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + # Placeholders always carry example_value in node.meta. + # non-placeholders always have no example_value in node.meta + if proxy.node.op == "placeholder": + assert "example_value" in proxy.node.meta, ( + f"placeholder {proxy} doesn't have 'example_value' in node.meta" + ) + else: + assert "example_value" not in proxy.node.meta, ( + f"{proxy.node.meta['example_value']}" + ) + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # Handle recursive calls here + if maybe_get_fake_mode(tensor) is tx.fake_mode: + pass + else: + cache_real_value_when_export(tx, proxy, tensor) + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + tensor, tx.fake_mode + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + kwargs = { + "is_tensor": target_cls + in (TensorVariable, TensorWithTFOverrideVariable), + } + assert "source" in options and options["source"] is not None + kwargs["source"] = options["source"] + tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs) + + if tensor.device.type != "meta" and ( + maybe_get_fake_mode(tensor) is not tx.fake_mode + ): + raise InternalTorchDynamoError( + "`tensor` needs to be a `FakeTensor`" + f"wrapped by this instance of Dynamo. Found: {tensor}" + ) + + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) + + +# This is 2 in the above comment (wrapping the output of a traced op) +def _wrap_fx_proxy( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # with preserve_rng_state(): + # only allow_non_graph_fake in this instance because we handle the non-fake + # cases properly below. + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + + return handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + +# This handles wrapping of the output of an op traced into the graph +def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls): + import torch._functorch.vmap + import torch._subclasses.fake_tensor + import torch._utils + + if isinstance(example_value, torch.Tensor): + # Check if the result is a sparse tensor - + # We generally don't support sparse tensor so better to graph break here + if is_sparse_any(example_value) and ( + not tx.export or not config.capture_sparse_compute + ): + unimplemented( + gb_type="Attempted to wrap sparse Tensor with VariableTracker", + context=str(example_value), + explanation="torch.compile does not support sparse Tensors with VariableTracker", + hints=[*graph_break_hints.SUPPORTABLE], + ) + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # specially here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. + # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target is torch.random.set_rng_state + ): + return TorchInGraphFunctionVariable(proxy.node.target) + elif ( + proxy.node.target is torch._C._DisableFuncTorch + or proxy.node.target is torch.cuda._is_in_bad_fork + ): + return UserDefinedObjectVariable(example_value) + elif istype(example_value, torch.Size) and all( + isinstance(x, int) for x in example_value + ): + sizes = [ConstantVariable.create(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + set_example_value(proxy.node, example_value) + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable.create(None, **options), + ) + else: + proxy_i = proxy.tracer.create_proxy( + kind="call_function", + target=operator.getitem, + args=(proxy, i), + kwargs={}, + ) + + if "source" in options: + # This path should only trigger for list stealing, so it's + # safe to use `GetItemSource`. + assert isinstance(example_value, list) + source = options["source"] + options_i = options.copy() + options_i["source"] = GetItemSource( + base=source, index=i, index_is_slice=False + ) + else: + # use the same options object as parent + options_i = options + + # WARNING: this assumes the same target_cls as this tuple/list call + unpacked.append( + wrap_fx_proxy_cls( + target_cls=target_cls, + tx=tx, + proxy=proxy_i, + example_value=val, + **options_i, + ) + ) + if isinstance(example_value, torch.Size): + # NB: Keep the old proxy around. See SizeVariable for an + # explanation why + return SizeVariable(unpacked, proxy, **options) + elif istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ( + f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" + ) + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable.create(None, **options) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + tx.output.current_tracer.track_produced_symints(example_value, proxy) + set_example_value(proxy.node, example_value) + return SymNodeVariable.create(tx, proxy, example_value, **options) + elif ( + isinstance(example_value, torch.Stream) + and proxy.node.target is get_external_object_by_index + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + index = None + if proxy.node.target is get_external_object_by_index: + index = proxy.node.args[0] + return StreamVariable(proxy, example_value, index, **options) + elif ( + isinstance(example_value, torch.Event) + and proxy.node.target is get_external_object_by_index + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + index = None + if proxy.node.target is get_external_object_by_index: + index = proxy.node.args[0] + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, index, **options) + elif ( + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Event) + ) or proxy.node.target in [ + device_interface.Event + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, None, **options) + elif proxy.node.target == "query" and proxy.node.op == "call_method": + set_example_value(proxy.node, example_value) + return ConstantVariable(example_value, **options) + elif ( + example_value is not None + and isinstance(example_value, torch.Event) + and proxy.node.target == "record_event" + and proxy.node.op == "call_method" + ): + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, None, **options) + elif isinstance(example_value, int) and ( + proxy.node.target + in [ + torch.sym_int, + getattr, + operator.getitem, + torch._utils._element_size, + torch.seed, + operator.mod, + torch._functorch.vmap._validate_and_get_batch_size, + torch._functorch.predispatch._vmap_increment_nesting, + torch._functorch.predispatch._vmap_decrement_nesting, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + # This always wants to be in the graph, even if the constraint + # results in a constant int + torch._constrain_as_size, + ] + or ( + # TODO: this is a little sus, because we didn't check what the self is + proxy.node.op == "call_method" and proxy.node.target == "bit_length" + ) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch.backends.cuda.SDPAParams): + from .sdpa import SDPAParamsVariable + + set_example_value(proxy.node, example_value) + return SDPAParamsVariable(proxy, **options) + elif isinstance(example_value, bool) and ( + proxy.node.target + in [ + torch._C._are_functorch_transforms_active, + torch._C._functorch.is_batchedtensor, + torch.backends.cuda.is_flash_attention_available, + torch.backends.cuda.can_use_flash_attention, + torch.backends.cuda.can_use_efficient_attention, + torch._C._get_cudnn_sdp_enabled, + torch._C._get_flash_sdp_enabled, + torch._C._get_mem_efficient_sdp_enabled, + torch._C._get_math_sdp_enabled, + torch._C._get_overrideable_sdp_enabled, + "is_integer", + ] + + list(supported_const_comparison_op_values.keys()) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, (int, float, bool)) and ( + proxy.node.target is call_torchbind + or proxy.node.target is flat_apply + or (proxy.node.op == "call_method" and proxy.node.target == "item") + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + else: + unimplemented( + gb_type="torch.* op returned non-Tensor", + context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", + explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", + hints=[], + ) + + +def infer_subclass_type(value): + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. + return None + else: + return type(value) + + +def get_specialized_props(target_cls, tx, example_value, subclass_type): + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + return specialized_props + + +def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options +): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. + if proxy.node.op != "placeholder": + tx.output.current_tracer.track_produced_symints(example_value, proxy) + options.update(get_specialized_props(target_cls, tx, example_value, subclass_type)) + return target_cls(proxy, **options) + + +def get_automatic_dynamic_shapes_mark_as(): + if config.automatic_dynamic_shapes_mark_as == "dynamic": + return DimDynamic.DYNAMIC + elif config.automatic_dynamic_shapes_mark_as == "unbacked": + return DimDynamic.SIZE_LIKE_UNBACKED + elif config.automatic_dynamic_shapes_mark_as == "oblivious": + return DimDynamic.OBLIVIOUS_SIZE + else: + raise ValueError( + f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}" + ) + + +_DYNAMIC_SOURCES: Optional[set[str]] = None +_DYNAMIC_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_dynamic_sources() -> set[str]: + global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.dynamic_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash: + return _DYNAMIC_SOURCES + + # Config has changed or first time, (re)calculate the sources + _DYNAMIC_SOURCES = { + s + for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",") + if s + } + _DYNAMIC_SOURCES_CONFIG_HASH = current_hash + + return _DYNAMIC_SOURCES + + +def is_dynamic_source(source_name: str) -> bool: + dynamic_sources = get_dynamic_sources() + for pattern in dynamic_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked dynamic due to dynamic source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +def record_automatic_dynamic( + tx: "InstructionTranslator", name: str, e: torch.Tensor +) -> FrameStateSizeEntry: + # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset + ex_size = e.size() + if not is_sparse_any(e): + ex_stride = e.stride() + dim = e.dim() + + stride = [None] * dim + pending = [(ex_stride[i], -i) for i in range(dim)] + pending.sort(key=_nested_int_aware_sort) + candidates = {} + for i_stride, neg_i in pending: + i = -neg_i + stride[i] = candidates.get(i_stride, i_stride) + candidates.setdefault(i_stride * ex_size[i], InferStride(i)) + else: + stride = [] + + return process_automatic_dynamic( + tx, name, FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride)) + ) + + +_UNBACKED_SOURCES: Optional[set[str]] = None +_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_unbacked_sources() -> set[str]: + global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.unbacked_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash: + return _UNBACKED_SOURCES + + # Config has changed or first time, (re)calculate the sources + _UNBACKED_SOURCES = { + s + for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",") + if s + } + _UNBACKED_SOURCES_CONFIG_HASH = current_hash + + return _UNBACKED_SOURCES + + +def is_unbacked_source(source_name: str) -> bool: + unbacked_sources = get_unbacked_sources() + for pattern in unbacked_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked unbacked due to unbacked source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +# Performs automatic dynamic dim determination. +# Returns a SymbolicContext +def _automatic_dynamic( + e, tx, source, static_shapes, outer_only=False +) -> SymbolicContext: + # strided NT not supported + if e.is_nested and not isinstance( + e, torch.nested._internal.nested_tensor.NestedTensor + ): + unimplemented( + gb_type="Encountered strided NestedTensor in automatic dynamic dim determination", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + name = source.name + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + shape_env_to_source_to_symbol_cache = ( + prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None + ) + + # Get base context if the tensor is a view + view_base_context: Optional[SymbolicContext] = None + if e._is_view(): + base_source = AttrSource(source, "_base") + view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) + + if is_traceable_wrapper_subclass(e) and not outer_only: + # Get symbolic context for outer tensor + outer_context = _automatic_dynamic( + e, tx, source, static_shapes, outer_only=True + ) + + # Get symbolic contexts for inner tensors + inner_contexts = {} # mapping from attr -> symbolic context + attrs, _ = type(e).__tensor_flatten__(e) + for attr in attrs: + inner_tensor = getattr(e, attr) + inner_source = AttrSource(source, attr) + inner_contexts[attr] = _automatic_dynamic( + inner_tensor, tx, inner_source, static_shapes + ) + + return SubclassSymbolicContext( + dynamic_sizes=outer_context.dynamic_sizes, + dynamic_strides=outer_context.dynamic_strides, + constraint_sizes=outer_context.constraint_sizes, + constraint_strides=outer_context.constraint_strides, + view_base_context=view_base_context, + tensor_source=outer_context.tensor_source, + shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, + inner_contexts=inner_contexts, + ) + + if static_shapes and not is_dynamic_source(name): + return StatefulSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * e.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # We preserve the dynamism of inputs. For example, when users call + # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. + from torch.fx.experimental.symbolic_shapes import is_nested_int + + if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): + return StatefulSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC + for s in e.size() + ], + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # Prep for automatic dynamic + frame_state_entry = record_automatic_dynamic(tx, name, e) + + # TODO: index export_constraints ahead of time so we don't have to + # do a linear scan every time here + t_id = id(e) + dim2constraint = {} + + def update_dim2constraint(dim, constraint_range, name): + if dim in dim2constraint: + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + old_constraint_range, old_name = dim2constraint[dim] + new_constraint_range = StrictMinMaxConstraint( + vr=constraint_range.vr & old_constraint_range.vr, + warn_only=False, + ) + # It is possible for (non-None) old_name and name to be different + # but this will only happen the corresponding Dims can be derived equal. + new_name = old_name or name + dim2constraint[dim] = new_constraint_range, new_name + else: + dim2constraint[dim] = constraint_range, name + + from torch.export.dynamic_shapes import _RelaxedConstraint + + if tx.output.export_constraints: + for constraint in tx.output.export_constraints: + if isinstance(constraint, _RelaxedConstraint): + continue + if constraint.t_id == t_id: + update_dim2constraint( + constraint.dim, constraint.constraint_range, constraint.name + ) + + dynamic_sizes = [] + dynamic_strides = [] + constraint_sizes = [] + constraint_strides = [] + specialize_on = [] + for i in range(e.dim()): + # NB: mark dynamic has precedence over static + marked_strict_unbacked = i in getattr( + e, "_dynamo_strict_unbacked_indices", set() + ) + marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) + marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) + marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) + marked_static = i in getattr(e, "_dynamo_static_indices", set()) + + specialize_on.append(getattr(e, "_specialize_on", {}).get(i, [])) + + # Reflect the user directive in the frame_state + # For dynamic, apply None always + + normalized_source_name = normalize_source_name(source.name) + base_source = source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if marked_dynamic or ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i] + ): + # TODO: This can be batched + # TODO: Doing this here is kind of sus, maybe better to set this + # up when we initially created the FrameStateSizeEntry to bong + # into the mutable state + log.debug("automatic dynamic %s marked dynamic", name) + mark_size = [auto_unset] * e.dim() + mark_size[i] = auto_dynamic + frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size) + + # NB: both static and dynamic have precedence over + automatic_dynamic_size = ( + config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i) + ) + # NB: previously, if size was dynamic, we wouldn't make its stride + # dynamic. But now, because of InferStride concept, we will properly + # not make stride dynamic even if it's wobbling + automatic_dynamic_stride = ( + config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i) + ) + + if is_dynamic_source(name): + log.debug("%s marked dynamic via source whitelist", name) + automatic_dynamic_size = True + + if is_unbacked_source(name): + log.debug("%s marked unbacked via source whitelist", name) + automatic_dynamic_size = True + + automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride + + # We will process constraints first, as they will imply that we + # have a dynamic dimension + # Precedence: export constraints > eager constraints + constraint = dim2constraint.get(i) + if constraint is None: + constraint_size = None + constraint_stride = None + if marked_dynamic and not config.allow_ignore_mark_dynamic: + # constraint_stride is deliberaly kept None because no easy way to provide value ranges for mark dynamic + constraint_stride = None + if hasattr(e, "_dynamo_dynamic_range"): + dim_range = [ + dr for dr in e._dynamo_dynamic_range if dr.dim == i + ].pop() + if dim_range.min is None and dim_range.max is None: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + else: + from torch.fx.experimental.symbolic_shapes import ( + StrictMinMaxConstraint, + ) + + constraint_size = StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), + warn_only=False, + ) + else: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif marked_strict_unbacked: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif not marked_static and automatic_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", True) + if automatic_dynamic_size: + constraint_size = RelaxedUnspecConstraint(warn_only=True) + if automatic_dynamic_stride: + constraint_stride = RelaxedUnspecConstraint(warn_only=True) + else: + if not marked_static and not config.automatic_dynamic_shapes: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + constraint_size = None + constraint_stride = None + else: + constraint_size, name_ = constraint + constraint_stride = None + dim_name = f"{name}.size()[{i}]" + tx.output.shape_env.source_name_to_debug_name[dim_name] = name_ + constraint_sizes.append(constraint_size) + constraint_strides.append(constraint_stride) + + if marked_unbacked or is_unbacked_source(name): + dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED + elif ( + constraint_size is not None + or marked_dynamic + or marked_weak_dynamic + or is_nested_int(e.size()[i]) + ): + # NB: We could assert static_shapes is False here, but it + # seems better to allow the user to override symbolic_context in this + # case + if automatic_dynamic: + dynamic_size = get_automatic_dynamic_shapes_mark_as() + else: + dynamic_size = DimDynamic.DYNAMIC + elif static_shapes or config.assume_static_by_default or marked_static: + dynamic_size = DimDynamic.STATIC + else: + # TODO: When does this show up? + dynamic_size = DimDynamic.DUCK + + if constraint_stride is not None: + dynamic_stride = DimDynamic.DYNAMIC + else: + dynamic_stride = DimDynamic.INFER_STRIDE + + dynamic_sizes.append(dynamic_size) + dynamic_strides.append(dynamic_stride) + + return StatefulSymbolicContext( + dynamic_sizes=dynamic_sizes, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + specialize_on=specialize_on, + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + +# See note [Tensor Fakification and Symbol Caching] +def wrap_to_fake_tensor_and_record( + e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None +): + if ( + type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) + or isinstance(e, torch.Tensor) + or is_traceable_wrapper_subclass(e) + ): + assert source is not None + static_shapes, _reason = tensor_always_has_static_shape( + e, + is_tensor, + tensor_source=source, + ) + + if not parent_context: + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + else: + # Parent contexts are passed in when we are recursively creating + # fake tensors for subclasses. A better design would be not to create a + # parent/child relationship, but to recursively call _automatic_dynamic + # as we recursively call wrap_to_fake_tensor_and_record. This runs + # into bugs around how meta_utils knows and works to create fake tensors + # with tensor subclasses. Ideally, dynamo would drive both the recursive + # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. + assert isinstance(source, AttrSource) + inner_context_name = source.member + symbolic_context = parent_context.inner_contexts[inner_context_name] + + log.debug( + "wrap_to_fake %s %s %s %s", + source.name, + tuple(e.shape), + symbolic_context, + type(e), + ) + + # Note [enable_python_dispatcher in dynamo] + # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses + # have no way to know (purely based off of global state) if they are currently being run under compile or not. + # we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors + # can check it to know if they are running in an eager context or not + with enable_python_dispatcher(): + fake_e = wrap_fake_exception( + lambda: tx.fake_mode.from_tensor( + e, + source=source, + symbolic_context=symbolic_context, + ) + ) + if ( + source is not None + and isinstance(fake_e, FakeTensor) + and (sym_val := fake_e.item_memo) is not None + ): + tx.output.tracked_fakes.append( + TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context) + ) + + if is_traceable_wrapper_subclass(fake_e): + attrs, _ = fake_e.__tensor_flatten__() + for attr in attrs: + fake_inner = getattr(fake_e, attr) + inner = getattr(e, attr) + inner_source = AttrSource(source, attr) + wrap_to_fake_tensor_and_record( + inner, + tx, + source=inner_source, + is_tensor=isinstance(fake_inner, torch.Tensor), + parent_context=symbolic_context, + ) + + tx.output.tracing_context.tensor_to_context[e] = symbolic_context + if is_sparse_any(fake_e): + # TODO: for TensorGuards, this eventually may need more + # fields for the size/stride of any other constituents + values = fake_e._values() if fake_e.is_sparse else fake_e.values() + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + # TODO: revise this, but for now this stride instead of () + # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1 + "stride": (1,) * fake_e.ndim, + "values_size": values.size(), + "values_stride": values.stride(), + } + else: + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + "stride": fake_e.stride(), + } + + if ( + is_tensor + and not (static_shapes and source.is_specialized_nn_module()) + and not is_constant_source(source) + ): + tx.output.tracked_fakes.append( + TrackedFake(fake_e, source, symbolic_context) + ) + tx.output.tracked_fakes_id_to_source[id(e)].append(source) + + return fake_e + else: + return e + + +class SourcelessBuilder: + """ + Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects + that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over + .), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, + there may be reasons to represent it as a ListVariable internally. + + NOTE - Objects produced here are born UNGUARDED due to the nature of sources! + + NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant + if/else type->VariableTracker trees that were cropping up all over dynamo. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + fast_handler = SourcelessBuilder._type_handlers.get(value_type) + if fast_handler: + return fast_handler(tx, value) + + if isinstance(value, VariableTracker): + # This is always valid to call, and useful for recursive calls. + return value + elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): + return UserDefinedObjectVariable(value) + elif ConstantVariable.is_literal(value): + return ConstantVariable.create(value) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value)(value) + elif callable(value) and UserDefinedClassVariable.is_supported_new_method( + value + ): + # NamedTuple._make uses an alias of tuple.__new__ + obj = trace_rules.lookup_callable(value.__self__)(value.__self__) + return GetAttrVariable(obj, "__new__") + elif is_function_or_wrapper(value): + return trace_rules.lookup(value)(value) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + return EnumVariable(value) + elif isinstance(value, (type, abc.ABCMeta)): + return UserDefinedClassVariable(value) + elif isinstance(value, types.MethodWrapperType): + return MethodWrapperVariable(value) + elif ( + isinstance(value, types.MethodType) + # We only want to support sourceless class objects here + # An instance variable is not allowed and it should have source + and isinstance(value.__self__, (type, abc.ABCMeta)) + ): + # value is a classmethod + assert getattr(value.__self__, value.__func__.__name__) == value + cls_obj_vt = SourcelessBuilder.create(tx, value.__self__) + try: + return cls_obj_vt.var_getattr(tx, value.__func__.__name__) + except NotImplementedError: + pass # failthrough to unimplemented branch + elif isinstance(value, torch.fx.graph_module.GraphModule): + return SourcelessGraphModuleVariable(value) + elif isinstance(value, torch.utils._pytree.TreeSpec): + return UserDefinedObjectVariable(value) + elif PlacementVariable.is_placement(value): + return PlacementVariable(value) + elif DeviceMeshVariable.is_device_mesh(value): + return DeviceMeshVariable(value) + elif value is functools.wraps: + return FunctoolsWrapsVariable(value) + elif isinstance(value, re.Pattern): + return ConstantLikeVariable(value) + elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString): + return ConstantVariable.create(str(value)) + elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)): + return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable( + value + ) + elif isinstance(value, (types.GenericAlias, types.UnionType)): + return TypingVariable(value) + elif is_namedtuple(value): + output = [ + SourcelessBuilder.create(tx, getattr(value, name)) + for name in namedtuple_fields(type(value)) + ] + return NamedTupleVariable(output, tuple_cls=type(value)) + elif ( + isinstance(value, torch.SymInt) + and value.node.expr in tx.output.bound_symbols + ): + proxy = tx.output.bound_symbols[value.node.expr] + return SymNodeVariable.create(tx, proxy) + unimplemented( + gb_type="Unexpected type in sourceless builder", + context=f"{value_type.__module__}.{value_type.__qualname__}", + explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + @staticmethod + def wrap_constant_literal(value): + assert ConstantVariable.is_literal(value) + return ConstantVariable.create(value=value) + + @staticmethod + def make_type_handlers(): + create = SourcelessBuilder.create + handlers = {} + for t in common_constant_types: + handlers[t] = lambda tx, value: ConstantVariable(value) + handlers[set] = lambda tx, value: SetVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[dict] = lambda tx, value: ConstDictVariable( + {create(tx, k): create(tx, v) for k, v in value.items()}, + type(value), + mutation_type=ValueMutationNew(), + ) + handlers[list] = lambda tx, value: ListVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[tuple] = lambda tx, value: TupleVariable( + [create(tx, x) for x in value] + ) + handlers[torch.Size] = lambda tx, value: SizeVariable( + [create(tx, x) for x in value] + ) + handlers[collections.OrderedDict] = handlers[dict] + handlers[immutable_dict] = handlers[dict] + handlers[immutable_list] = handlers[list] + handlers[random.Random] = lambda tx, value: RandomClassVariable() + handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) + + handlers[torch.DispatchKeySet] = lambda tx, value: DispatchKeySetVariable( + value, mutation_type=ValueMutationNew() + ) + handlers[torch._functorch.pyfunctorch.FuncTorchInterpreter] = ( + lambda tx, value: FuncTorchInterpreterVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + handlers[torch.distributions.constraints._Real] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints._Interval] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints.Constraint] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + def passthrough(tx: "InstructionTranslator", value): + return value + + for cls in VariableTrackerMeta.all_subclasses: + handlers[cls] = passthrough + return handlers + + +SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers() + + +class SourcelessUserDefinedObjectBuilder: + """ + SourceLessBuilder does not return a UserDefinedObjectVariable, but in some + cases it might be ok to return UserDefinedObjects. In such case, use this + builder. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + if issubclass(value_type, MutableMapping): + return MutableMappingVariable(value, mutation_type=ValueMutationNew()) + elif isinstance(value, torch.nn.Module): + return UnspecializedNNModuleVariable( + value, mutation_type=ValueMutationNew() + ) + else: + return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..44fca37314a62b79df1374270065f6d5837bfaab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py @@ -0,0 +1,3286 @@ +""" +Built-in function and type variable tracking for TorchDynamo's symbolic execution. + +This module contains variable tracker classes for Python built-in functions, types, +and operations during graph compilation. It handles symbolic execution of: + +- Built-in functions (len, getattr, isinstance, etc.) +- Type constructors (int, float, str, list, dict, etc.) +- Built-in operators and methods +- Special Python constructs (super, hasattr, etc.) + +Key classes: +- BuiltinVariable: Tracks built-in functions and handles their execution +- TypeVariable: Manages type constructor calls and type checking +- SuperVariable: Handles super() calls in class hierarchies + +These variable trackers ensure that built-in Python operations are correctly +handled during symbolic execution, either by executing them directly when safe +or by creating appropriate graph nodes when needed. +""" + +import contextlib +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +import types +import typing +import unittest +from collections import defaultdict, OrderedDict +from collections.abc import Callable, Iterable, KeysView, Sequence +from typing import Any, cast, TYPE_CHECKING, Union + +import torch +from torch import sym_float, sym_int +from torch._subclasses.meta_utils import is_sparse_any +from torch.overrides import BaseTorchFunctionMode +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config, graph_break_hints, polyfills, variables +from ..exc import ( + AttributeMutationError, + ObservedAttributeError, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, + Unsupported, + UserError, + UserErrorType, +) +from ..guards import GuardBuilder, install_guard +from ..replay_record import DummyModule +from ..source import ( + AttrSource, + GetItemSource, + GlobalSource, + is_constant_source, + Source, + TypeSource, +) +from ..utils import ( + check_constant_args, + check_numpy_ndarray_args, + check_unspec_or_constant_args, + check_unspec_python_args, + cmp_name_to_op_mapping, + dict_methods, + extract_fake_example_value, + frozenset_methods, + get_fake_value, + guard_if_dyn, + is_tensor_getset_descriptor, + is_wrapper_or_member_descriptor, + istype, + numpy_operator_wrapper, + proxy_args_kwargs, + raise_args_mismatch, + set_methods, + str_methods, + tensortype_to_dtype, +) +from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeysVariable, + DictViewVariable, + FrozensetVariable, + is_hashable, + SetVariable, +) +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SizeVariable, + TupleIteratorVariable, + TupleVariable, +) +from .streams import EventVariable, StreamVariable +from .tensor import ( + FakeItemVariable, + supported_comparison_ops, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .user_defined import ( + MutableMappingVariable, + UserDefinedDictVariable, + UserDefinedObjectVariable, + UserDefinedVariable, +) + + +if TYPE_CHECKING: + # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +log = logging.getLogger(__name__) + + +IN_PLACE_DESUGARING_MAP = { + operator.iadd: operator.add, + operator.isub: operator.sub, + operator.imul: operator.mul, + operator.ifloordiv: operator.floordiv, + operator.itruediv: operator.truediv, + operator.imod: operator.mod, + operator.imatmul: operator.imatmul, + operator.ilshift: operator.lshift, + operator.irshift: operator.rshift, + operator.ipow: operator.pow, + operator.iand: operator.and_, + operator.ior: operator.or_, + operator.ixor: operator.xor, +} + + +_HandlerCallback = Callable[ + ["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None +] +_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]] +polyfill_fn_mapping = { + operator.eq: polyfills.cmp_eq, + operator.ne: polyfills.cmp_ne, + operator.lt: polyfills.cmp_lt, + operator.le: polyfills.cmp_le, + operator.gt: polyfills.cmp_gt, + operator.ge: polyfills.cmp_ge, +} + +bin_ops = ( + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +) + +bin_int_ops = ( + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +) + +un_int_ops = (operator.invert,) + +tensor_and_int_ops = ( + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +) + +un_ops = ( + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +) + +BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {} + +# These functions represent the r* versions of the above ops +# Basically, if __add__(1, Tensor) is called, it is translated +# to __radd__(Tensor, 1). +# In the builtin var, we check if there is a tensor in the first args position, +# if not, we swap the args and use the r* version of the op. +BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {} + + +def populate_builtin_to_tensor_fn_map() -> None: + global BUILTIN_TO_TENSOR_FN_MAP + if len(BUILTIN_TO_TENSOR_FN_MAP) > 0: + # Only populate once; after there are elements present no need to + # repopulate + return + most_recent_func: Callable[..., Any] | None = None + + class GetMethodMode(BaseTorchFunctionMode): + """ + Mode to extract the correct methods from torch function invocations + (Used to get the correct torch.Tensor methods from builtins) + """ + + def __torch_function__( + self, + func: Callable[..., Any], + types: Any, + args: Sequence[Any] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + kwargs = kwargs or {} + nonlocal most_recent_func + most_recent_func = func + return func(*args, **kwargs) + + inp0 = torch.ones(1) + inp1 = torch.ones(1) + inp0_int = torch.ones(1, dtype=torch.int32) + inp1_int = torch.ones(1, dtype=torch.int32) + with GetMethodMode(): + setups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [ + (lambda o: o(inp0), un_ops), + (lambda o: o(inp0_int), un_int_ops), + (lambda o: o(inp0, inp1), bin_ops), + (lambda o: o(inp0_int, inp1_int), bin_int_ops), + (lambda o: o(inp0_int, 0), tensor_and_int_ops), + ] + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + setup_fn(op) + assert most_recent_func is not None + BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func + + # gather the reverse functions + rsetups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [ + ( + lambda o: o(1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: o(1, inp1_int), bin_int_ops), + (lambda o: o(0, inp0_int), tensor_and_int_ops), + ] + + rskips = {operator.matmul, operator.imatmul, operator.getitem} + for setup_fn, op_list in rsetups_and_oplists: + for op in op_list: + if op in rskips: + continue + setup_fn(op) + assert most_recent_func is not None + if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: + BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func + + +class BuiltinVariable(VariableTracker): + """ + A VariableTracker that represents a built-in value (functions and operators). + A lot of the code here assumes it will be a function object. + + The BuiltinVariable class wraps Python built-in functions (like len, isinstance, etc.) + and operators (like +, -, *, etc.) to enable symbolic execution during tracing. This allows + Dynamo to properly handle these operations when converting Python code to FX graphs while + maintaining correct semantics and enabling optimizations. + """ + + _SENTINEL = object() + _nonvar_fields = { + "fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + def create_with_source(cls, value: Any, source: Source) -> "BuiltinVariable": + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + return cls(value, source=source) + + @staticmethod + @functools.cache + def _constant_fold_functions() -> set[Callable[..., Any]]: + fns: set[Callable[..., Any]] = { + abs, + all, + any, + bool, + callable, + chr, + complex, + divmod, + float, + getattr, + int, + len, + max, + min, + ord, + pow, + repr, + round, + str, + str.format, + sum, + type, + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.truth, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.sub, + operator.getitem, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + operator.index, + } + from .tensor import supported_comparison_ops + + fns.update(supported_comparison_ops.values()) + fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) + return fns + + def can_constant_fold_through(self) -> bool: + return self.fn in self._constant_fold_functions() + + @staticmethod + @functools.cache + def _fx_graph_functions() -> set[Callable[..., Any]]: + fns = { + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.getitem, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + } + return fns # type: ignore[return-value] + + @staticmethod + @functools.cache + def _binops() -> dict[ + Callable[..., object], tuple[list[str], Callable[..., object]] + ]: + # function -> ([forward name, reverse name, in-place name], in-place op) + fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = { + operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), + operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), + operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), + operator.truediv: ( + ["__truediv__", "__rtruediv__", "__itruediv__"], + operator.itruediv, + ), + operator.floordiv: ( + ["__floordiv__", "__rfloordiv__", "__ifloordiv__"], + operator.ifloordiv, + ), + operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), + pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.lshift: ( + ["__lshift__", "__rlshift__", "__ilshift__"], + operator.ilshift, + ), + operator.rshift: ( + ["__rshift__", "__rrshift__", "__irshift__"], + operator.irshift, + ), + operator.xor: (["__xor__", "__rxor__", "__ixor__"], operator.xor), + # NB: The follow binary operators are not supported for now, since the + # corresponding magic methods aren't defined on SymInt / SymFloat: + # operator.matmul + # divmod + # operator.and_ + # operator.or_ + } + return fns + + @staticmethod + @functools.cache + def _binop_handlers() -> dict[ + Callable[..., object], + list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ], + ]: + # Multiple dispatch mechanism defining custom binop behavior for certain type + # combinations. Handlers are attempted in order, and will be used if the type checks + # match. They are expected to have the signature: + # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker + from .functions import BaseUserFunctionVariable, UserFunctionVariable + from .nn_module import NNModuleVariable + from .tensor import supported_const_comparison_ops + from .torch import BaseTorchVariable + from .user_defined import ( + UserDefinedClassVariable, + UserDefinedObjectVariable, + UserDefinedVariable, + ) + + # Override table contains: op_fn -> [list of handlers] + op_handlers: dict[Any, list[Any]] = {} + for ( + op, + (magic_method_names, in_place_op), + ) in BuiltinVariable._binops().items(): + op_handlers[op] = [] + op_handlers[in_place_op] = [] + + forward_name, reverse_name, inplace_name = magic_method_names + + # User-defined args (highest precedence) + def user_defined_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + forward_name: str = forward_name, + reverse_name: str = reverse_name, + ) -> VariableTracker: + # Manually handle reversing logic if needed (e.g. call __radd__) + + # TODO: If we expand this to handle tensor args, we need to manually + # handle cases like this: + # + # class A(int): + # def __radd__(self, other): + # print("woof") + # torch.randn(3) + A(3) + # + # In this example, A.__radd__() is not called -> nothing is printed, because + # Tensor.__add__ only does a subtype test against int, ignoring the subclass. + # To be fully correct, we should not call A.__radd__() here, and there may be + # other cases to reason about and add exceptions for. + if isinstance(a, UserDefinedVariable): + return a.call_method(tx, forward_name, [b], {}) + else: + return b.call_method(tx, reverse_name, [a], {}) + + op_handlers[op].append( + ((UserDefinedVariable, VariableTracker), user_defined_handler) + ) + op_handlers[op].append( + ((VariableTracker, UserDefinedVariable), user_defined_handler) + ) + + def user_defined_inplace_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + forward_name: str = inplace_name, + ) -> VariableTracker: + return a.call_method(tx, forward_name, [b], {}) + + op_handlers[in_place_op].append( + ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler) + ) + + # Dynamic shape args + def dynamic_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + fn: Callable[..., Any] = op, + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", fn, *proxy_args_kwargs([a, b], {}) + ), + ) + + op_handlers[op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # NB: Prefer out-of-place op when calling in-place op to generate valid graph + op_handlers[in_place_op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # Special cases - lower precedence but still prefer these over constant folding + + # List-like addition (e.g. [1, 2] + [3, 4]) + def tuple_add_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> VariableTracker: + return TupleVariable([*a.items, *b.unpack_var_sequence(tx)]) + + def size_add_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> VariableTracker: + return SizeVariable([*a.items, *b.unpack_var_sequence(tx)]) + + list_like_addition_handlers: list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ] = [ + # NB: Prefer the tuple-specific logic over base logic because of + # some SizeVariable weirdness. Specifically, the tuple-specific logic + # drops the subclass type (e.g. SizeVariable) and returns TupleVariables. + ( + (SizeVariable, SizeVariable), + size_add_handler, + ), + ( + (SizeVariable, TupleVariable), + size_add_handler, + ), + ( + (TupleVariable, SizeVariable), + size_add_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ( + (ConstantVariable, TupleVariable), + lambda tx, a, b: TupleVariable( + [ + *a.unpack_var_sequence(tx), + *b.items, + ], + ), + ), + ( + ( + ListVariable, + (BaseListVariable, ConstantVariable, ListIteratorVariable), + ), + lambda tx, a, b: ListVariable( + [*a.items, *b.unpack_var_sequence(tx)], + mutation_type=ValueMutationNew(), + ), + ), + ( + (BaseListVariable, BaseListVariable), + lambda tx, a, b: type(a)( + [ + *a.items, + *b.items, + ] + ), + ), + ] + op_handlers[operator.add].extend(list_like_addition_handlers) + + def list_iadd_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> Any: + if a.is_immutable() or not b.has_unpack_var_sequence(tx): + # Handler doesn't apply + return None + + seq = b.unpack_var_sequence(tx) + tx.output.side_effects.mutation(a) + a.items.extend(seq) + return a + + list_like_iadd_handlers: list[Any] = [ + ( + (ListVariable, VariableTracker), + list_iadd_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ] + op_handlers[operator.iadd].extend(list_like_iadd_handlers) + + # List-like expansion (e.g. [1, 2, 3] * 3) + def expand_list_like( + tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker + ) -> VariableTracker: + if not isinstance(lst, BaseListVariable) and lst.is_python_constant(): + lst, const = const, lst + try: + assert isinstance(lst, BaseListVariable) + return lst.__class__( + items=lst.items * const.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + except MemoryError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + list_like_expansion_handlers: list[ + tuple[ + tuple[type[VariableTracker], type[VariableTracker]], + _HandlerCallback, + ] + ] = [ + ((ListVariable, ConstantVariable), expand_list_like), + ((TupleVariable, ConstantVariable), expand_list_like), + ((ConstantVariable, ListVariable), expand_list_like), + ((ConstantVariable, TupleVariable), expand_list_like), + ] + op_handlers[operator.mul].extend(list_like_expansion_handlers) + + def create_cmp_op_handlers( + op: Callable[..., Any], + ) -> list[tuple[tuple[_TrackersType, _TrackersType], _HandlerCallback]]: + def compare_by_value( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + try: + return ConstantVariable(op(a.value, b.value)) # type: ignore[attr-defined] + except TypeError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + result: list[ + tuple[ + tuple[ + _TrackersType, + _TrackersType, + ], + _HandlerCallback, + ] + ] = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in polyfill_fn_mapping: + # For constants, speedup the comparison instead of using + # polyfill. Removing this line causes major regression for pr + # time benchmark - add_loop_eager. + result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + op_var = BuiltinVariable(op) + # Special handling of SymNode variable + result.extend( + [ + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handler( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {} + ) + + result.append(((VariableTracker, VariableTracker), handler)) + return result + + result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in supported_const_comparison_ops.values() and op.__name__.startswith( + "is_" + ): + # Tensor is None, List is not None, etc + none_result = op(object(), None) + + def never( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return ConstantVariable(none_result) + + obj_op_none = never + none_op_obj = never + + types_that_are_never_none = ( + TensorVariable, + SymNodeVariable, + NNModuleVariable, + BaseListVariable, + UserDefinedVariable, + BaseUserFunctionVariable, + ConstDictVariable, + BaseTorchVariable, + ) + result.extend( + [ + ( + (types_that_are_never_none, ConstantVariable), + obj_op_none, + ), + ( + (ConstantVariable, types_that_are_never_none), + none_op_obj, + ), + ] + ) + + op_var = BuiltinVariable(op) + result.extend( + [ + ( + ( + (UserFunctionVariable, BuiltinVariable), + (UserFunctionVariable, BuiltinVariable), + ), + lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)), + ), + ( + ( + NNModuleVariable, + NNModuleVariable, + ), + lambda tx, a, b: ConstantVariable( + op( + tx.output.get_submodule(a.module_key), + tx.output.get_submodule(b.module_key), + ) + ), + ), + ( + (UserDefinedObjectVariable, UserDefinedObjectVariable), + compare_by_value, + ), + ( + (UserDefinedClassVariable, UserDefinedClassVariable), + compare_by_value, + ), + ( + ( + (StreamVariable, EventVariable, ConstantVariable), + (StreamVariable, EventVariable, ConstantVariable), + ), + compare_by_value, + ), + ( + (TensorVariable, VariableTracker), + op_var._comparison_with_tensor, + ), + ( + (VariableTracker, TensorVariable), + op_var._comparison_with_tensor, + ), + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handle_is( + tx: "InstructionTranslator", + left: VariableTracker, + right: VariableTracker, + ) -> VariableTracker | None: + # If the two objects are of different type, we can safely return False + # and True for `is` and `is not`, respectively + if type(left) is not type(right): + return ConstantVariable.create(op.__name__ != "is_") + if left is right: + return ConstantVariable.create(op(left, right)) + if ( + istype(left, variables.ExceptionVariable) + and istype(right, variables.ExceptionVariable) + and left.exc_type is not right.exc_type + ): + return ConstantVariable.create(op(left, right)) + return None + + result.append(((VariableTracker, VariableTracker), handle_is)) # type: ignore[arg-type] + + return result + + for op in supported_comparison_ops.values(): + assert callable(op) + assert op not in op_handlers + op_handlers[op] = create_cmp_op_handlers(op) + + return op_handlers + + @staticmethod + def _find_binop_handler( + op: Callable[..., Any], a_type: type[VariableTracker], b_type: type + ) -> list[_HandlerCallback] | None: + handlers = BuiltinVariable._binop_handlers().get(op) + if handlers is None: + return None + + matches = [] + for (type1, type2), handler in handlers: + if issubclass(a_type, type1) and issubclass(b_type, type2): + matches.append(handler) + return matches + + def can_insert_in_graph(self) -> bool: + return self.fn in self._fx_graph_functions() + + def __init__(self, fn: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.fn = fn + + def __repr__(self) -> str: + if self.fn is None: + name = "None" + else: + name = self.fn.__name__ + + return f"{self.__class__.__name__}({name})" + + def as_python_constant(self) -> Any: + return self.fn + + def as_proxy(self) -> Any: + DTYPE = { + bool: torch.bool, + int: torch.int64, + float: torch.float64, + } + if self.fn in DTYPE: + return DTYPE[self.fn] + return super().as_proxy() + + def reconstruct(self, codegen: "PyCodegen") -> None: + name = self.fn.__name__ + assert self.fn.__module__ == "builtins" + assert name not in codegen.tx.f_globals, "shadowed global" + codegen.append_output(codegen.create_load_global(name, add=True)) + + def constant_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool: + return check_constant_args(args, kwargs) + + def tensor_args(self, *args: VariableTracker) -> bool: + any_tensor = False + for arg in args: + if isinstance(arg, variables.GetAttrVariable): + return False + any_tensor = any_tensor or arg.is_tensor() + return any_tensor + + def tensor_args_type(self, arg_types: list[type]) -> bool: + any_tensor = False + for arg_type in arg_types: + if issubclass(arg_type, variables.GetAttrVariable): + return False + any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable) + return any_tensor + + def python_and_tensor_constant_only( + self, *args: VariableTracker, **kwargs: VariableTracker + ) -> bool: + tensor_args = [] + non_tensor_args = [] + for i in itertools.chain(args, kwargs.values()): + if i.is_tensor(): + tensor_args.append(i) + else: + non_tensor_args.append(i) + return all( + is_constant_source(t.source) if t.source is not None else False + for t in tensor_args + ) and self.constant_args(*non_tensor_args) + + @staticmethod + def unwrap_unspec_args_kwargs( + args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> tuple[list[Any], dict[str, Any]]: + return [x.as_python_constant() for x in args], { + k: v.as_python_constant() for k, v in kwargs.items() + } + + def has_constant_handler( + self, args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> bool: + return self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ) + + @staticmethod + def _make_handler( + fn: Callable[..., Any], arg_types: list[type], has_kwargs: bool + ) -> Callable[ + [ + "InstructionTranslator", + Sequence[VariableTracker], + dict[str, VariableTracker], + ], + VariableTracker | None, + ]: + from .lazy import LazyVariableTracker + + obj = BuiltinVariable(fn) + handlers: list[_HandlerCallback] = [] + + if any(issubclass(t, LazyVariableTracker) for t in arg_types): + return lambda tx, args, kwargs: obj.call_function( + tx, [v.realize() for v in args], kwargs + ) + + if inspect.isclass(fn) and ( + issubclass(fn, Exception) + # GeneratorExit doesn't inherit from Exception + # >>> issubclass(GeneratorExit, Exception) + # False + or fn is GeneratorExit + ): + + def create_exception_class_object( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if fn is AssertionError and not all( + x.is_python_constant() and isinstance(x.as_python_constant(), str) + for x in args + ): + unimplemented( + gb_type="assert with non-string message", + context=str(args), + explanation="Dynamo only supports asserts with string messages", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return variables.ExceptionVariable(fn, args, kwargs) + + return create_exception_class_object + + if obj.can_insert_in_graph() and not ( + fn is operator.getitem + and not issubclass(arg_types[0], variables.TensorVariable) + ): + if obj.tensor_args_type(arg_types): + return obj._handle_insert_op_in_graph + elif has_kwargs: + # need runtime check for kwargs + handlers.append(obj._handle_insert_op_in_graph) + + # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.) + # NB: Tensor args are handled above and not here + if len(arg_types) == 2 and not has_kwargs: + # Try to find a handler for the arg types; otherwise, fall through to constant handler + binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types) + if not binop_handlers: + pass + elif len(binop_handlers) == 1: + (binop_handler,) = binop_handlers + handlers.append(lambda tx, args, _: binop_handler(tx, *args)) + else: + + def call_binop_handlers( + tx: "InstructionTranslator", args: Any, _: Any + ) -> Any: + # pyrefly: ignore [not-iterable] + for fn in binop_handlers: + rv = fn(tx, *args) + if rv: + return rv + return None + + handlers.append(call_binop_handlers) + + self_handler = getattr(obj, f"call_{fn.__name__}", None) + if self_handler: + + def call_self_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + try: + # pyrefly: ignore [not-callable] + return self_handler(tx, *args, **kwargs) + except TypeError: + # Check if binding is bad. inspect signature bind is expensive. + # So check only when handler call fails. + try: + # pyrefly: ignore [bad-argument-type] + inspect.signature(self_handler).bind(tx, *args, **kwargs) + except TypeError as e: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + log.warning( # noqa: G200 + "incorrect arg count %s %s and no constant handler", + self_handler, + e, + ) + unimplemented( + gb_type="invalid call to builtin op handler", + context=f"invalid args to {self_handler}: {args} {kwargs}", + explanation=f"Encountered TypeError when trying to handle op {fn.__name__}", + hints=[*graph_break_hints.DIFFICULT], + ) + else: + raise + except Unsupported as exc: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + raise + # Actually, we will handle this just fine + exc.remove_from_stats() + return None + + handlers.append(call_self_handler) + + if obj.can_constant_fold_through(): + if ( + all(issubclass(x, ConstantVariable) for x in arg_types) + and not has_kwargs + ): + + def constant_fold_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + # fast path + try: + res = fn( + *[x.as_python_constant() for x in args], + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, res) + + else: + + def constant_fold_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + # path with a runtime check + if check_unspec_or_constant_args(args, kwargs): + try: + res = fn( + *[x.as_python_constant() for x in args], + **{ + k: v.as_python_constant() for k, v in kwargs.items() + }, + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, res) + return None + + handlers.append(constant_fold_handler) + + def call_unimplemented(args: Sequence[VariableTracker]) -> None: + real_arg_types = [arg.python_type_name() for arg in args] + unimplemented( + gb_type="Failed to trace builtin operator", + context=f"builtin {fn.__name__} {arg_types} {has_kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` " + f"with argument types {real_arg_types} (has_kwargs {has_kwargs})", + hints=[ + f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. " + f"Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch.", + ], + ) + + if len(handlers) == 0: + return lambda tx, args, kwargs: call_unimplemented(args) + elif len(handlers) == 1: + (handler,) = handlers + + def builtin_dispatch( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + rv = handler(tx, args, kwargs) + if rv: + return rv + call_unimplemented(args) + return rv + + else: + + def builtin_dispatch( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + rv = None + for fn in handlers: + rv = fn(tx, args, kwargs) + if rv: + return rv + call_unimplemented(args) + return rv + + return builtin_dispatch + + def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + if len(args) == 0: + unimplemented( + gb_type="unimplemented builtin op vars() with no arguments", + context=f"vars: {self} {args}", + explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments", + hints=[*graph_break_hints.SUPPORTABLE], + ) + assert len(args) == 1 + # vars(obj) is obj.__dict__ if __dict__ is present else TypeError + try: + return args[0].var_getattr(tx, "__dict__") + except ObservedAttributeError: + raise_observed_exception(TypeError, tx) + + def _handle_insert_op_in_graph( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + if kwargs and not self.tensor_args(*args, *kwargs.values()): + return None + + # insert handling for torch function here + from .builder import SourcelessBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP + if can_dispatch_torch_function(tx, args, kwargs): + # Only remap the fn to tensor methods if we aren't exporting + # export serde does not handle method descriptors today + if not tx.export: + # Ensure the builtin maps are populated before accessing them + populate_builtin_to_tensor_fn_map() + # Use sourceless builder, we built the map ourselves + if not args[0].is_tensor(): + if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: + func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + + tmp = args[0] + # swap args and call reverse version of func + args[0] = args[1] # type: ignore[index] + args[1] = tmp # type: ignore[index] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + else: + func = self.fn + + fn_var = SourcelessBuilder.create(tx, func) + + return dispatch_torch_function(tx, fn_var, args, kwargs) + + fn = self.fn + try: + # Constant fold for constant tensor and python constants + if self.python_and_tensor_constant_only(*args, **kwargs): + from ..bytecode_transformation import unique_id + from .functions import invoke_and_store_as_constant + + return invoke_and_store_as_constant( + tx, fn, unique_id(fn.__name__), args, kwargs + ) + + if fn in IN_PLACE_DESUGARING_MAP and isinstance( + args[0], variables.ConstantVariable + ): + # In-place operators like += usually mustate tensor + # values, but in the edge case of immutable values they + # re-bind the variable. + # + # The easiest way to keep the graph consistent in this + # scenario is to de-sugar eagerly. + fn = IN_PLACE_DESUGARING_MAP[fn] + args = [args[0], args[1]] # type: ignore[assignment] + + if fn is operator.getitem and isinstance(args[1], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. Rewrite as a regular torch op which will + # trace fine + fn = torch.select + args = [ + args[0], + variables.ConstantVariable.create(0), + args[1], + ] # type: ignore[assignment] + + # Interaction between ndarray and tensors: + # We prefer the tensor op whenever there are tensors involved + # NB: Use exact type check here - NumpyNdarrayVariable is a TensorVariable + # subclass but should NOT trigger the tensor path + if check_numpy_ndarray_args(args, kwargs) and not any( + type(arg) is TensorVariable for arg in args + ): + proxy = tx.output.create_proxy( + "call_function", + numpy_operator_wrapper(fn), + *proxy_args_kwargs(args, kwargs), + ) + + return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy) + + if fn is operator.eq and len(args) == 2 and args[0].is_tensor(): + # Dynamo expects `__eq__` str while operator.eq gives just `eq` + # TODO - supporting all comparison operators could also work but + # it fails lots of tests because graph str changes. + return args[0].call_method(tx, "__eq__", list(args[1:]), kwargs) + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs(args, kwargs), + ) + if any(isinstance(arg, FakeItemVariable) for arg in args): + return wrap_fx_proxy_cls( + FakeItemVariable, + tx, + proxy, + ) + elif check_unspec_python_args(args, kwargs): + _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) + raw_value = fn(*_args, **_kwargs) + + need_unwrap = any( + x.need_unwrap + for x in itertools.chain(args, kwargs.values()) + if isinstance(x, variables.UnspecializedPythonVariable) + ) + + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx, + proxy, + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + elif all(isinstance(x, SymNodeVariable) for x in args): + return SymNodeVariable.create(tx, proxy, None) + else: + # Work around for vision_maskrcnn due to precision difference + # specialize the dividend when float divide by tensor + if fn is operator.truediv and isinstance( + args[0], variables.UnspecializedPythonVariable + ): + args = list(args) + args[0] = args[0].as_python_constant() + return wrap_fx_proxy(tx, proxy) + + except NotImplementedError: + unimplemented( + gb_type="unimplemented builtin op on tensor arguments", + context=f"partial tensor op: {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + call_function_handler_cache: dict[ + tuple[object, ...], + Callable[ + [ + "InstructionTranslator", + Sequence[VariableTracker], + dict[str, VariableTracker], + ], + VariableTracker, + ], + ] = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + key: tuple[object, ...] + if kwargs: + kwargs = {k: v.realize() for k, v in kwargs.items()} + key = (self.fn, *(type(x) for x in args), True) + else: + key = (self.fn, *(type(x) for x in args)) + + handler = self.call_function_handler_cache.get(key) + if not handler: + self.call_function_handler_cache[key] = handler = self._make_handler( # type: ignore[assignment] + self.fn, [type(x) for x in args], bool(kwargs) + ) + assert handler is not None + return handler(tx, args, kwargs) # type: ignore[return-value] + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.fn is object and name == "__setattr__": + assert len(args) == 3 + assert len(kwargs) == 0 + obj, name_var, val = args + obj = obj.realize() + if ( + isinstance(obj, UserDefinedObjectVariable) + and tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + return obj.method_setattr_standard(tx, name_var, val) + + if name == "__new__": + # Supported __new__ methods + if self.fn is object and len(args) == 1: + assert len(kwargs) == 0 + return tx.output.side_effects.track_new_user_defined_object( + self, args[0], args[1:] + ) + + if self.fn is dict and len(args) == 1 and not kwargs: + dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is dict: + return dict_vt + # We don't have to set the underlying dict_vt in + # UserDefinedDictVariable because it will be set to empty + # ConstDictVariableTracker in the constructor. + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if ( + self.fn is tuple + and len(args) == 2 + and args[1].has_force_unpack_var_sequence(tx) + and not kwargs + ): + if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: + init_args = args[1].force_unpack_var_sequence(tx) + return variables.TupleVariable( + init_args, mutation_type=ValueMutationNew() + ) + + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if self.fn is list: + list_vt = ListVariable([], mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is list: + return list_vt + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if ( + self.fn in (float, complex) + and len(args) == 1 + and ( + (self.fn is float and name in ("fromhex", "hex")) + or (name == "from_number" and sys.version_info >= (3, 14)) + ) + ): + if args[0].is_python_constant(): + try: + fn = getattr(self.fn, name) + res = fn(args[0].as_python_constant()) + return variables.ConstantVariable.create(res) + except (OverflowError, ValueError) as e: + raise_observed_exception( + type(e), + tx, + args=list(map(ConstantVariable.create, e.args)), + ) + + if self.fn is object and name == "__init__": + # object.__init__ is a no-op + return variables.ConstantVariable(None) + + if self.fn is dict and name == "fromkeys": + return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) + + if self.fn is dict: + resolved_fn = getattr(self.fn, name) + if resolved_fn in dict_methods: + if isinstance(args[0], variables.UserDefinedDictVariable): + # pyrefly: ignore [missing-attribute] + return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.ConstDictVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is set: + resolved_fn = getattr(self.fn, name) + if resolved_fn in set_methods: + if isinstance(args[0], variables.UserDefinedSetVariable): + # pyrefly: ignore [missing-attribute] + return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.SetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is frozenset: + resolved_fn = getattr(self.fn, name) + if resolved_fn in frozenset_methods: + if isinstance(args[0], variables.FrozensetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is str and len(args) >= 1: + resolved_fn = getattr(self.fn, name) + if resolved_fn in str_methods: + # Only delegate to ConstantVariable, not other types that happen to be constants + if isinstance(args[0], ConstantVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is float and len(args) >= 1: + # Only delegate to ConstantVariable, not other types that happen to be constants + if isinstance(args[0], ConstantVariable): + return ConstantVariable.create( + getattr(float, name)(args[0].as_python_constant()) + ) + + return super().call_method(tx, name, args, kwargs) + + def _call_int_float( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Handle cases like int(torch.seed()) + # Also handle sym_float to sym_int cases + if arg.is_tensor() or isinstance(arg, SymNodeVariable): + if arg.is_tensor(): + item = arg.call_method(tx, "item", [], {}) + else: + item = arg + fn_ = sym_int if self.fn is int else sym_float + from torch._dynamo.variables.builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (item.as_proxy(),), + {}, + ), + ) + return None + + call_int = _call_int_float + call_float = _call_int_float + + def call_bool( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. + # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 + if isinstance(arg, SymNodeVariable): + # Note that we delay specializing on symbolic values to avoid + # unnecessary guards. Specialization will happen later if, e.g., the + # resulting boolean is used for branching. + if isinstance(arg.sym_num, torch.SymBool): + return arg + + # Emulate `nb_bool` of int/float objects + # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944 + # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882 + assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) + return SymNodeVariable.create(tx, arg.as_proxy() != 0) + + # TODO handle more cases and merge this with this with `generic_jump`. + return None + + def call_repr(self, tx: "InstructionTranslator", arg): + """Handle repr() on user defined objects.""" + if isinstance(arg, variables.UserDefinedObjectVariable): + repr_method = arg.value.__repr__ + + if type(arg.value).__repr__ is object.__repr__: + # Default repr - build and trace it + fn_vt = VariableTracker.build(tx, repr_method) + return fn_vt.call_function(tx, [], {}) + else: + # Custom repr - inline the method for tracing + bound_method = repr_method.__func__ + fn_vt = VariableTracker.build(tx, bound_method) + return fn_vt.call_function(tx, [arg], {}) + + def call_str( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Handle `str` on a user defined function or object + if isinstance(arg, (variables.UserFunctionVariable)): + return variables.ConstantVariable.create(value=str(arg.fn)) + elif isinstance(arg, (variables.UserDefinedObjectVariable)): + # Check if object has __str__ method + if hasattr(arg.value, "__str__"): + str_method = arg.value.__str__ + elif hasattr(arg.value, "__repr__"): + # account for __repr__ functions when __str__ is absent + str_method = arg.value.__repr__ + else: + unimplemented( + gb_type="failed to call str() on user defined object", + context=str(arg), + explanation="User defined object has no __str__ or __repr__ method", + hints=[*graph_break_hints.USER_ERROR], + ) + + if type(arg.value).__str__ is object.__str__: + # Rely on the object str method + try: + # pyrefly: ignore [unbound-name] + return variables.ConstantVariable.create(value=str_method()) + except AttributeError: + # Graph break + return None + # pyrefly: ignore [unbound-name] + elif is_wrapper_or_member_descriptor(str_method): + unimplemented( + gb_type="Attempted to a str() method implemented in C/C++", + context="", + explanation=f"{type(arg.value)} has a C/C++ based str method. This is not supported.", + hints=["Write the str method in Python"], + ) + else: + # Overrides for custom str method + # Pass method as function to call tx.inline_user_function_return + bound_method = str_method.__func__ # type: ignore[attr-defined] + + try: + # Only supports certain function types + user_func_variable = VariableTracker.build(tx, bound_method) + except AssertionError: + # Won't be able to do inline the str method, return to avoid graph break + log.warning("Failed to create UserFunctionVariable", exc_info=True) + return None + + # Inline the user function + return user_func_variable.call_function(tx, [arg], {}) + elif isinstance(arg, (variables.ExceptionVariable,)): + if len(arg.args) == 0: + value = f"{arg.exc_type}" + else: + value = ", ".join(a.as_python_constant() for a in arg.args) + return variables.ConstantVariable.create(value=value) + return None + + def _call_min_max( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker | None: + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + return self._call_min_max_seq(tx, items) + elif len(args) == 2: + return self._call_min_max_binary(tx, args[0], args[1]) + elif len(args) > 2: + return self._call_min_max_seq(tx, args) + return None + + def _call_min_max_seq( + self, tx: "InstructionTranslator", items: Sequence[VariableTracker] + ) -> VariableTracker: + assert len(items) > 0 + if len(items) == 1: + return items[0] + + return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) # type: ignore[arg-type,return-value] + + def _call_min_max_binary( + self, + tx: "InstructionTranslator", + a: VariableTracker | None, + b: VariableTracker | None, + ) -> VariableTracker | None: + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return None + if self.tensor_args(a, b): + if not a.is_tensor(): + a, b = b, a + assert a.is_tensor() + + # result of an item call is a scalar convert to a tensor + if isinstance(a, FakeItemVariable): + a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function( + tx, [a], {} + ) + + # Dynamic input does not get resolved, rather, gets stored as call_function + if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + type(a), + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.fn, + *proxy_args_kwargs([a, b], {}), + ), + ) + + # convert min/max to torch ops + if b.is_python_constant(): + fn: VariableTracker + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + fn = variables.NumpyVariable(np.clip) + else: + fn = variables.TorchInGraphFunctionVariable(torch.clamp) + kwargs = {"min": b} if (self.fn is max) else {"max": b} + result = fn.call_function(tx, [a], kwargs) + else: + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + np_fn = {max: np.maximum, min: np.minimum}[self.fn] + fn = variables.NumpyVariable(np_fn) + else: + torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn] + fn = variables.TorchInGraphFunctionVariable(torch_fn) + result = fn.call_function(tx, [a, b], {}) + + # return unspec if both a, b are unspec or const + if all( + isinstance( + i, + ( + variables.UnspecializedPythonVariable, + variables.ConstantVariable, + ), + ) + for i in [a, b] + ): + if any(isinstance(val, FakeItemVariable) for val in [a, b]): + return variables.FakeItemVariable.from_tensor_variable(result) + + if b.is_python_constant(): + raw_b = b.as_python_constant() + else: + raw_b = b.raw_value # type: ignore[attr-defined] + if self.fn is max: + raw_res = max(a.raw_value, raw_b) # type: ignore[attr-defined] + else: + raw_res = min(a.raw_value, raw_b) # type: ignore[attr-defined] + + need_unwrap = any( + x.need_unwrap + for x in [a, b] + if isinstance(x, variables.UnspecializedPythonVariable) + ) + return variables.UnspecializedPythonVariable.from_tensor_variable( + result, raw_res, need_unwrap + ) + # otherwise return tensor + else: + return result + elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + py_fn = torch.sym_max if self.fn is max else torch.sym_min + proxy = tx.output.create_proxy( + "call_function", py_fn, *proxy_args_kwargs([a, b], {}) + ) + return SymNodeVariable.create(tx, proxy, None) + elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + value = self.fn( + a.as_python_constant(), + b.as_python_constant(), + ) + return ConstantVariable.create(value) + return None + + call_min = _call_min_max + call_max = _call_min_max + + def call_abs( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # Call arg.__abs__() + abs_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__abs__")], {} + ) + return abs_method.call_function(tx, [], {}) + + def call_pos( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # Call arg.__pos__() + pos_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__pos__")], {} + ) + return pos_method.call_function(tx, [], {}) + + def call_index( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + if arg.is_tensor(): + unimplemented( + gb_type="unsupported index(Tensor)", + context="", + explanation="Dynamo does not support tracing builtin index() on a Tensor", + hints=[], + ) + + arg = guard_if_dyn(arg) + constant_value = operator.index(arg) + return variables.ConstantVariable.create(constant_value) + + def call_round( + self, + tx: "InstructionTranslator", + arg: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # Call arg.__round__() + round_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__round__")], {} + ) + return round_method.call_function(tx, args, kwargs) + + def call_range( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker | None: + if check_unspec_or_constant_args(args, {}): + return variables.RangeVariable(args) + elif self._dynamic_args(*args): + args = tuple( + variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args + ) + return variables.RangeVariable(args) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool: + return any(isinstance(x, SymNodeVariable) for x in args) or any( + isinstance(x, SymNodeVariable) for x in kwargs.values() + ) + + def call_slice( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.SliceVariable(args, tx) + + def _dyn_proxy( + self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + ) + + # NOTE must handle IteratorVariable separately! + def _call_iter_tuple_list( + self, + tx: "InstructionTranslator", + obj: VariableTracker | None = None, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + assert not isinstance(obj, variables.IteratorVariable) + + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + cls = variables.BaseListVariable.cls_for(self.fn) + if obj is None: + return cls( + [], + mutation_type=ValueMutationNew(), + ) + elif obj.has_unpack_var_sequence(tx): + if obj.source and not is_constant_source(obj.source): + if isinstance(obj, TupleIteratorVariable): + install_guard( + obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) + ) + else: + if ( + getattr(obj, "source", False) + and isinstance(obj, ConstDictVariable) + and not istype(obj, (SetVariable, FrozensetVariable)) + ): + tx.output.guard_on_key_order.add(obj.source) + + if isinstance(obj, variables.MappingProxyVariable): + # This could be an overguarding, but its rare to iterate + # through a mapping proxy and not use the keys. + install_guard( + obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK) + ) + elif not isinstance(obj, variables.UnspecializedNNModuleVariable): + # Prevent calling __len__ method for guards, the tracing + # of __iter__ will insert the right guards later. + install_guard( + obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + + return cls( + list(obj.unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + return None + + def _call_iter_tuple_generator( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), # exhaust generator + mutation_type=ValueMutationNew(), + ) + + def _call_tuple_list( + self, + tx: "InstructionTranslator", + obj: VariableTracker | None = None, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + elif isinstance(obj, variables.LocalGeneratorObjectVariable) or ( + isinstance(obj, UserDefinedObjectVariable) + and obj.has_force_unpack_var_sequence(tx) + ): + return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + + def call_iter( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # avoid the overhead of tracing the polyfill if we already know the class implemented __iter__ + if isinstance( + obj, + ( + variables.ListVariable, + variables.RangeVariable, + variables.IteratorVariable, + variables.ConstDictVariable, + variables.NNModuleVariable, + variables.TensorVariable, + ), + ): + return obj.call_method(tx, "__iter__", [], {}) + else: + # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. + # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call + # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. + # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() + # with an integer argument starting at 0, until __getitem__ raises IndexError + ret = variables.UserFunctionVariable( + polyfills.builtins.iter_ # type: ignore[arg-type] + ).call_function(tx, [obj, *args], {}) + + if args: + # iter(obj, sentinel) returns an object that implements + # __iter__ and __next__ methods (UserDefinedObjectVariable) + # Wrap the return value in a IteratorVariable subclass (LazyObjectIteratorVariable) + # that forwards the next_variable call to the object. + ret = variables.ObjectIteratorVariable(ret) + return ret + + call_tuple = _call_tuple_list + call_list = _call_tuple_list + + def call_callable( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + from .functions import BaseUserFunctionVariable, FunctoolsPartialVariable + from .nn_module import NNModuleVariable + + if isinstance( + arg, + ( + variables.UserDefinedClassVariable, + BaseUserFunctionVariable, + FunctoolsPartialVariable, + NNModuleVariable, + ), + ): + return variables.ConstantVariable.create(True) + elif isinstance(arg, UserDefinedVariable): + return variables.ConstantVariable.create(callable(arg.value)) + elif isinstance( + arg, + ( + ConstantVariable, + SymNodeVariable, + TensorVariable, + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ): + return variables.ConstantVariable.create(False) + else: + return None + + def call_cast( + self, _: Any, *args: VariableTracker, **kwargs: VariableTracker + ) -> VariableTracker | None: + if len(args) == 2: + return args[1] + + unimplemented( + gb_type="bad args to builtin cast()", + context=f"got args {args} {kwargs}", + explanation="Dynamo expects exactly 2 args to builtin cast().", + hints=["Ensure your call to cast() has exactly 2 arguments."], + ) + + def call_dir( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + if isinstance(arg, variables.UserDefinedClassVariable): + return VariableTracker.build(tx, dir(arg.value)) + if isinstance(arg, BuiltinVariable): + return VariableTracker.build(tx, dir(arg.fn)) + return None + + def call_dict( + self, + tx: "InstructionTranslator", + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + @staticmethod + def call_custom_dict( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + args_list = list(args) + if ( + len(args_list) == 1 + and isinstance(args_list[0], variables.GetAttrVariable) + and isinstance(args_list[0].obj, variables.UserDefinedClassVariable) + and not tx.output.side_effects.has_pending_mutation(args_list[0].obj) + ): + # Forward the GetAttrVariable(foo, "__dict__") to a realized vt of + # VT(foo.__dict__). This simplifies the construction of the new + # dict. + args_list[0] = args_list[0].get_forwarded_dict(tx) + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [VariableTracker.build(tx, user_cls), *args_list], + kwargs, + ) + + @staticmethod + def call_custom_dict_fromkeys( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + if user_cls not in {dict, OrderedDict, defaultdict}: + unimplemented( + gb_type="Unsupported dict type for fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict", + hints=[ + f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.", + ], + ) + if kwargs: + # Only `OrderedDict.fromkeys` accepts `value` passed by keyword + if ( + user_cls is not OrderedDict + or len(args) != 1 + or len(kwargs) != 1 + or "value" not in kwargs + ): + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "1 args and 1 kwargs (`value`)", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + args = (*args, kwargs.pop("value")) + if len(args) == 0: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "at least 1 args", + f"{len(args)} args", + ) + if len(args) == 1: + args = (*args, ConstantVariable.create(None)) + if len(args) != 2: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "2 args", + f"{len(args)} args", + ) + # pyrefly: ignore [bad-unpacking] + arg, value = args + DictVariableType = ( + ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable + ) + + if isinstance(arg, dict): + arg_list = [ConstantVariable.create(k) for k in arg] + return DictVariableType( + # pyrefly: ignore [bad-argument-type] + dict.fromkeys(arg_list, value), + user_cls, + mutation_type=ValueMutationNew(), + ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + # pyrefly: ignore [bad-argument-type] + dict.fromkeys(keys, value), + user_cls, + mutation_type=ValueMutationNew(), + ) + + unimplemented( + gb_type="failed to call dict.fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + "arguments could not be automatically converted to a list, " + "or some dict key is not hashable.", + hints=[ + "Manually convert the argument to a list.", + "Ensure all keys are hashable.", + ], + ) + + def call_set( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # Can we merge this implementation and call_dict's one? + assert not kwargs + if not args: + return SetVariable([], mutation_type=ValueMutationNew()) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"set() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if istype(arg, variables.SetVariable): + return arg.clone(mutation_type=ValueMutationNew()) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return SetVariable(items, mutation_type=ValueMutationNew()) + elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, KeysView + ): + iter_fn = arg.var_getattr(tx, "__iter__") + if isinstance(iter_fn, variables.UserMethodVariable): + out = tx.inline_user_function_return(iter_fn, args, kwargs) + if isinstance(out, SetVariable): + return out + return BuiltinVariable(set).call_set(tx, out) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin set()")], + ) + + def call_frozenset( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + assert not kwargs + if not args: + return FrozensetVariable([]) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"frozenset() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if istype(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return FrozensetVariable(items) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin frozenset()")], + ) + + def call_zip( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + if kwargs: + if not (len(kwargs) == 1 and "strict" in kwargs): + raise_args_mismatch( + tx, + "zip", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", + ) + strict = kwargs.pop("strict", ConstantVariable.create(False)) + iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args] + return variables.ZipVariable( + iter_args, + strict=strict.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + + def call_len( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + try: + return args[0].call_method(tx, "__len__", list(args[1:]), kwargs) + except AttributeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) + + def call_getitem( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return args[0].call_method(tx, "__getitem__", list(args[1:]), kwargs) + + def call_isinstance( + self, + tx: "InstructionTranslator", + arg: VariableTracker, + isinstance_type_var: VariableTracker, + ) -> VariableTracker: + try: + arg_type = arg.python_type() + except NotImplementedError: + unimplemented( + gb_type="builtin isinstance() cannot determine type of argument", + context=f"isinstance({arg}, {isinstance_type_var})", + explanation=f"Dynamo doesn't have a rule to determine the type of argument {arg}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + isinstance_type = isinstance_type_var.as_python_constant() + if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: + + def _tensor_isinstance( + tensor_var: VariableTracker, tensor_type: Any + ) -> bool: + def check_type(ty: Any) -> bool: + if ty not in tensortype_to_dtype: + example_val = arg.as_proxy().node.meta["example_value"] + if ( + is_traceable_wrapper_subclass(example_val) + and ty is torch.nn.parameter.Parameter + ): + # N.B: we are calling isinstance directly on the example value. + # torch.nn.Parameter has a meta-class that overrides __isinstance__, + # the isinstance check here allows us to invoke that logic. + return isinstance(example_val, ty) + else: + return issubclass(arg.python_type(), ty) + + dtypes = tensortype_to_dtype[ty] + # pyrefly: ignore [missing-attribute] + return arg.dtype in dtypes + + if type(tensor_type) is tuple: + return any(check_type(ty) for ty in tensor_type) + else: + return check_type(tensor_type) + + return variables.ConstantVariable.create( + _tensor_isinstance(arg, isinstance_type) + ) + # UserDefinedObject with C extensions can have torch.Tensor attributes, + # so break graph. + if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, types.MemberDescriptorType + ): + unimplemented( + gb_type="isinstance() called on user defined object with C extensions", + context=f"isinstance({arg}, {isinstance_type})", + explanation="User-defined object with C extensions can have torch.Tensor " + "attributes; intentionally graph breaking.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + # handle __instancecheck__ defined in user class + if ( + isinstance(arg, variables.UserDefinedObjectVariable) + and "__instancecheck__" in isinstance_type.__class__.__dict__ + ): + return variables.ConstantVariable.create( + isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) + ) + + if isinstance(arg, variables.UserDefinedExceptionClassVariable): + # pyrefly: ignore [unbound-name] + return ConstantVariable.create(isinstance(arg_type, isinstance_type)) + + isinstance_type_tuple: tuple[type, ...] + if isinstance(isinstance_type, type) or callable( + # E.g. isinstance(obj, typing.Sequence) + getattr(isinstance_type, "__instancecheck__", None) + ): + isinstance_type_tuple = (isinstance_type,) + elif isinstance(isinstance_type, types.UnionType): + isinstance_type_tuple = isinstance_type.__args__ + elif isinstance(isinstance_type, tuple) and all( + isinstance(tp, type) or callable(getattr(tp, "__instancecheck__", None)) + for tp in isinstance_type + ): + isinstance_type_tuple = isinstance_type + else: + raise_observed_exception( + TypeError, + tx, + args=[ + "isinstance() arg 2 must be a type, a tuple of types, or a union" + ], + ) + + try: + # NB: `isinstance()` does not call `__subclasscheck__` but use `__instancecheck__`. + # But usually `isinstance(obj, type_info)` and `issubclass(type(obj), type_info)` gives + # the same result. + # WARNING: This might run arbitrary user code `__subclasscheck__` and we did not trace + # through it. This is a limitation of the current implementation. + # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it + # might not be a big issue and we trade off it for performance. + # pyrefly: ignore [unbound-name] + val = issubclass(arg_type, isinstance_type_tuple) + except TypeError: + # pyrefly: ignore [unbound-name] + val = arg_type in isinstance_type_tuple + return variables.ConstantVariable.create(val) + + def call_issubclass( + self, + tx: "InstructionTranslator", + left_ty: VariableTracker, + right_ty: VariableTracker, + ) -> VariableTracker: + """Checks if first arg is subclass of right arg""" + try: + left_ty_py = left_ty.as_python_constant() + right_ty_py = right_ty.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="issubclass() with non-constant arguments", + context=f"issubclass({left_ty}, {right_ty})", + explanation="issubclass() with non-constant arguments not supported.", + hints=[ + "Make sure your arguments are types.", + *graph_break_hints.USER_ERROR, + ], + ) + + # WARNING: This might run arbitrary user code `__subclasscheck__`. + # See the comment in call_isinstance above. + # pyrefly: ignore [unbound-name] + return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) + + def call_super( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return variables.SuperVariable(a, b) + + def call_next( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + arg = args[0] + try: + return arg.next_variable(tx) + except ObservedUserStopIteration: + if len(args) == 2: + return args[1] + raise + except Unsupported as ex: + if isinstance(arg, variables.BaseListVariable): + ex.remove_from_stats() + return arg.items[0] + raise + + def call_hasattr( + self, tx: "InstructionTranslator", obj: VariableTracker, attr: VariableTracker + ) -> VariableTracker | None: + if attr.is_python_constant(): + name = attr.as_python_constant() + if isinstance(obj, variables.BuiltinVariable): + return variables.ConstantVariable(hasattr(obj.fn, name)) + return obj.call_obj_hasattr(tx, name) + return None + + def call_map( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + *seqs: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + strict = ConstantVariable.create(False) + if kwargs: + if sys.version_info >= (3, 14): + if not (len(kwargs) == 1 and "strict" in kwargs): + raise_args_mismatch( + tx, + "map", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", + ) + strict = kwargs.pop("strict", ConstantVariable.create(False)) + else: + raise_args_mismatch( + tx, + "map", + "0 kwargs", + f"{len(kwargs)} kwargs", + ) + + seq_list = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable( + fn, + seq_list, # type: ignore[arg-type] + strict=strict.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + + def call_filter( + self, tx: "InstructionTranslator", fn: VariableTracker, seq: VariableTracker + ) -> VariableTracker: + seq_or_list = ( + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + ) + return variables.FilterVariable( + fn, + seq_or_list, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + if self.fn is object: + # for object, we can just directly read the attribute + try: + value = getattr(self.fn, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + # pyrefly: ignore [unbound-name] + if not callable(value): + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, value, source) + return variables.GetAttrVariable(self, name, source=source) + + def call_getattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + default: VariableTracker | None = None, + ) -> VariableTracker | None: + if not name_var.is_python_constant(): + unimplemented( + gb_type="getattr() with non-constant name argument", + context=f"getattr({obj}, {name_var}, {default})", + explanation="getattr() with non-constant name argument is not supported", + hints=["Ensure the name argument of getattr() is a string"], + ) + + name = name_var.as_python_constant() + + # See NOTE [Tensor "grad" and "_grad" attr] + if obj.is_tensor() and name == "_grad": + name = "grad" + + if tx.output.side_effects.is_attribute_mutation(obj): + if isinstance(obj, variables.UnspecializedNNModuleVariable): + if ( + name + in ( + "named_parameters", + "parameters", + "named_buffers", + "buffers", + "named_modules", + "modules", + ) + and obj.is_state_mutated + and tx.output.side_effects.has_pending_mutation(obj) + ): + unimplemented( + gb_type="getattr() on nn.Module with pending mutation", + context=f"getattr({obj}, {name}, {default})", + explanation="Intentionally graph breaking on getattr() on a nn.Module " + "with a pending mutation", + hints=[], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): + return tx.output.side_effects.load_attr(obj, name) + + if default is not None: + hasattr_var = self.call_hasattr(tx, obj, name_var) + if hasattr_var is not None: + assert hasattr_var.is_constant_match(True, False) + if not hasattr_var.as_python_constant(): + return default + else: + return default + + source = obj.source and AttrSource(obj.source, name) + if name in {"__bases__", "__base__", "__flags__"}: + try: + value = obj.as_python_constant() + if isinstance(value, type): + if name == "__bases__": + tuple_args = [ + VariableTracker.build( + tx, b, source and GetItemSource(source, i) + ) + for i, b in enumerate(value.__bases__) + ] + return variables.TupleVariable(tuple_args, source=source) + if name == "__base__": + return VariableTracker.build(tx, value.__base__, source) + if name == "__flags__": + return ConstantVariable.create(value.__flags__) + except NotImplementedError: + pass + + if isinstance(obj, variables.NNModuleVariable): + return obj.var_getattr(tx, name) + elif isinstance( + obj, + ( + variables.TensorVariable, + variables.NamedTupleVariable, + variables.ConstantVariable, + variables.DistributedVariable, + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertWarns", + ) + ): + unimplemented( + gb_type="Failed to trace unittest method", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace unittest method `{name}` ", + hints=[ + f"Avoid calling `TestCase.{name}`. " + "Please report an issue to PyTorch.", + ], + ) + if obj.is_tensor(): + fake_val = obj.as_proxy().node.meta["example_value"] + if ( + isinstance(fake_val, torch.Tensor) + and is_sparse_any(fake_val) + and (not tx.export or not config.capture_sparse_compute) + ): + unimplemented( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + try: + return obj.var_getattr(tx, name) + except AsPythonConstantNotImplementedError: + # dont fallback on as_python_constant error because this leads + # to a failure later on, and leads to a wrong stacktrace + raise + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, variables.TorchInGraphFunctionVariable): + # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. + member = getattr(obj.value, name) + if isinstance( + member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member): + return variables.TorchInGraphFunctionVariable(member, source=source) + elif name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(obj, name, source=source) + else: + return None + elif isinstance(obj, DummyModule): + # TODO(mlazos) - Do we need this? + if obj.is_torch or name not in obj.value.__dict__: + member = getattr(obj.value, name) + else: + member = obj.value.__dict__[name] + + if config.replay_record_enabled: + tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr] + return VariableTracker.build(tx, member, source) + + elif istype(obj, variables.UserFunctionVariable) and name in ( + "__name__", + "__module__", + ): + return ConstantVariable.create(getattr(obj.fn, name)) + else: + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + + def call_setattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + val: VariableTracker, + ) -> VariableTracker | None: + if isinstance( + obj, + ( + variables.PlacementVariable, + variables.NamedTupleVariable, + variables.UserDefinedObjectVariable, + variables.NestedUserFunctionVariable, + variables.ExceptionVariable, + ), + ): + return obj.call_method(tx, "__setattr__", [name_var, val], {}) + elif ( + tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + name = name_var.as_python_constant() + if obj.is_tensor(): + from .builder import wrap_fx_proxy + + # Some special handling for tensor attributes. + if name == "requires_grad": + # TODO(voz): Make it work properly + unimplemented( + gb_type="setattr() on Tensor.requires_grad", + context=f"setattr({obj}, {name}, {val})", + explanation="setattr() on Tensor.requires_grad not supported. " + "Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " + "the middle of the graph, which AOTAutograd does not currently know how to handle.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + elif obj.dtype != val.dtype: # type: ignore[attr-defined] + unimplemented( + gb_type="Failed to mutate tensor data attribute to different dtype", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor to a new one with the same dtype", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + + # Remove the old reference in tracked fakes - if we don't do this + # new .data value size and shape differences will cause + # tracked fakes to produce incorrect guards. This is sound because the TensorVariable + # coming out of set_() below will be a new one, and get + # installed in tracked fakes. + to_remove = [ + tf for tf in tx.output.tracked_fakes if tf.source == obj.source + ] + for tf in to_remove: + tx.output.tracked_fakes.remove(tf) + + # Step 1 - disable grads + with dynamo_disable_grad(tx), torch.no_grad(): + # Step 2 - call `set_` + out = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + torch.Tensor.set_, + *proxy_args_kwargs([obj, val], {}), + ), + ) + + # Step 3 - drop the version counter - this is a step required to get + # .data setting to play correctly with the autograd engine. + # Essentially, dynamo is trying to faithfully preserve the (absurd) + # behavior of .data= from eager mode + def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor: + version = x._version + if version > 0: + version = version - 1 + torch._C._autograd._unsafe_set_version_counter((x,), (version,)) + return x + + tx.output.create_proxy( + "call_function", + _lower_version_count_by_1, + (out.as_proxy(),), + {}, + ) + _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) + # This handles options prop, guards and ends with a clone + # Step 4 - replace all reference to the current object with the new one + return out + elif name in ("_grad", "grad"): + # NOTE: [Tensor "grad" and "_grad" attr] + # _grad and grad share the same setter/getter, see + # THPVariable_properties, and here we make sure setting one + # enables reading `val` from the other, by routing all + # read/write to `grad`. + name = "grad" + elif is_tensor_getset_descriptor(name): + # Attribute like `torch.Tensor.real` has special setters we + # don't yet support; it's not as simple adding an entry to + # the side effect mapping. + unimplemented( + gb_type="Failed to set tensor attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo doesn't support setting these tensor attributes", + hints=[ + f"Don't mutate attribute '{name}' on tensors, or " + "move the mutation out of `torch.compile` region", + ], + ) + + tx.output.side_effects.store_attr(obj, name, val) + return val + elif isinstance(obj, variables.NNModuleVariable): + if not tx.output.is_root_tracer(): + raise AttributeMutationError( + "Can't inplace modify module params/buffers inside HigherOrderOp" + ) + if name_var.is_python_constant() and isinstance( + val, variables.TensorVariable + ): + assigning_fake_val = get_fake_value(val.as_proxy().node, tx) + + try: + getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) + except (AttributeError, ObservedAttributeError): + getattr_var = None + + if getattr_var is not None and getattr_var.is_tensor(): + # get_fake_val will get the same fake tensor + existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) + + # same tensor identity, setattr is a no-op + mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") + if ( + existing_fake_attr is assigning_fake_val + and mod_setattr is torch.nn.Module.__setattr__ + ): + return getattr_var + + obj.convert_to_unspecialized(tx) + return None + + def call_delattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + ) -> VariableTracker: + return obj.call_method(tx, "__delattr__", [name_var], {}) + + def call_type( + self, tx: "InstructionTranslator", obj: VariableTracker + ) -> VariableTracker: + try: + py_type = obj.python_type() + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None + + source = obj.source and TypeSource(obj.source) + if ( + source is None + and isinstance(obj, variables.UserDefinedObjectVariable) + and obj.cls_source + ): + source = obj.cls_source + if py_type is torch.Tensor: + # In some cases torch isn't available in globals + name = tx.output.install_global_by_id("", torch) + source = AttrSource(GlobalSource(name), "Tensor") + + return VariableTracker.build(tx, py_type, source) + + def call_reversed( + self, tx: "InstructionTranslator", obj: VariableTracker + ) -> VariableTracker | None: + if obj.has_unpack_var_sequence(tx): + items = list(reversed(obj.unpack_var_sequence(tx))) + return variables.TupleVariable(items) + return None + + def call_sorted( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable + ): + list_var = variables.ListVariable( + obj.force_unpack_var_sequence(tx), + mutation_type=ValueMutationNew(), + ) + list_var.call_method(tx, "sort", [], kwargs) + return list_var + return None + + # neg is a constant fold function, so we only get here if constant fold is not valid + def call_neg( + self, tx: "InstructionTranslator", a: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + (operator.neg)(a.as_proxy()), + sym_num=None, + ) + + if ( + isinstance(a, UserDefinedObjectVariable) + and a.call_obj_hasattr(tx, "__neg__").value # type: ignore[attr-defined] + ): + return a.call_method(tx, "__neg__", [], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_format( + self, + tx: "InstructionTranslator", + _format_string: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + format_string = _format_string.as_python_constant() + format_string = str(format_string) + return variables.StringFormatVariable.create(format_string, args, kwargs) + + def call_id( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): + nn_mod_variable = args[0] + mod = tx.output.get_submodule(nn_mod_variable.module_key) + return variables.ConstantVariable.create(id(mod)) + elif len(args) == 1 and isinstance( + args[0], + (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable), + ): + if args[0].source: + if isinstance(args[0], variables.UserDefinedClassVariable): + install_guard(args[0].source.make_guard(GuardBuilder.CLASS_MATCH)) + else: + install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) + constant_result = id(args[0].value) + return variables.ConstantVariable.create(constant_result) + elif len(args) == 1 and args[0].is_tensor(): + tensor_variable = cast(TensorVariable, args[0]) + return tensor_variable.call_id(tx) + elif istype(args[0], variables.UserFunctionVariable): + return variables.ConstantVariable.create(id(args[0].fn)) + elif istype(args[0], variables.SkipFunctionVariable): + return variables.ConstantVariable.create(id(args[0].value)) + elif istype(args[0], variables.FunctoolsPartialVariable): + return variables.ConstantVariable.create(id(args[0].fake_value)) + else: + unimplemented( + gb_type="id() with unsupported args", + context=str(args), + explanation=f"Dynamo doesn't know how to trace id() call with args {args}", + hints=[ + "Supported args are Tensors, and functions/nn.Modules/user-defined objects " + "from outside the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_deepcopy( + self, tx: "InstructionTranslator", x: VariableTracker + ) -> VariableTracker: + unimplemented( + gb_type="copy.deepcopy()", + context=f"copy.deepcopy({x})", + explanation="Dynamo does not support copy.deepcopy()", + hints=[ + "Avoid calling copy.deepcopy()", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def _comparison_with_tensor( + self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker + ) -> VariableTracker: + from .builder import wrap_fx_proxy_cls + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op in [operator.is_, operator.is_not]: + is_result = ( + left.is_tensor() + and right.is_tensor() + and id(extract_fake_example_value(left.as_proxy().node)) + == id(extract_fake_example_value(right.as_proxy().node)) + ) + if op is operator.is_: + return ConstantVariable.create(is_result) + else: + return ConstantVariable.create(not is_result) + + if op not in supported_tensor_comparison_op_values: + unimplemented( + gb_type="unsupported Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with Tensor arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and (left.size and right.size) is not None + and left.size != right.size + ): + try: + torch.broadcast_shapes(left.size, right.size) + except RuntimeError: + # not broadcastable, can't be compared + unimplemented( + gb_type="failed to broadcast when attempting Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo was unable to broad cast the arguments {left}, {right} " + f"when attempting to trace the comparison op {op.__name__}.", + hints=[*graph_break_hints.USER_ERROR], + ) + tensor_cls = left if left.is_tensor() else right + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return wrap_fx_proxy_cls( + type(tensor_cls), # handle Ndarrays and Tensors + tx, + proxy, + ) + + def _comparison_with_symnode( + self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker + ) -> VariableTracker: + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op not in supported_tensor_comparison_op_values: + unimplemented( + gb_type="unsupported SymNode comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with SymNode arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # This is seen in inspect signature where we check if the value is a default value + if isinstance(right, variables.UserDefinedClassVariable): + return variables.ConstantVariable(op(object(), None)) + + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return SymNodeVariable.create( + tx, + proxy, + sym_num=None, + ) + + def call_xor( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.xor, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + if isinstance( + a, + (DictKeysVariable, SetVariable, UserDefinedObjectVariable), + ): + return a.call_method(tx, "__xor__", [b], {}) + return None + + def call_ixor( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__ixor__", [b], {}) + return None + + def call_sub( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__sub__", [b], {}) + return None + + def call_isub( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__isub__", [b], {}) + return None + + def call_and_( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__and__", [b], {}) + # None no-ops this handler and lets the driving function proceed + return None + + def call_iand( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.iand, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__iand__", [b], {}) + return None + + def call_or_( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. + if isinstance( + a, + ( + ConstDictVariable, + DictKeysVariable, + MutableMappingVariable, + SetVariable, + UserDefinedDictVariable, + UserDefinedObjectVariable, + ), + ): + # TODO(guilhermeleobas): forward the call to b.__ror__(a) if + # a.__ror__(b) returns NotImplemented + return a.call_method(tx, "__or__", [b], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_ior( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.ior, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + # This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`. + if isinstance( + a, + ( + ConstDictVariable, + DictKeysVariable, + MutableMappingVariable, + SetVariable, + UserDefinedObjectVariable, + ), + ): + return a.call_method(tx, "__ior__", [b], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_not_( + self, tx: "InstructionTranslator", a: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.not_, *proxy_args_kwargs([a], {}) + ), + sym_num=None, + ) + + # Unwrap the underlying ConstDictVariable + if isinstance(a, DictViewVariable): + a = a.dv_dict + if isinstance(a, (ListVariable, ConstDictVariable)): + return ConstantVariable.create(len(a.items) == 0) + + return None + + def call_contains( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return a.call_method(tx, "__contains__", [b], {}) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn + + +@contextlib.contextmanager +def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: + from . import GradModeVariable + + gmv = GradModeVariable.create(tx, False) + try: + gmv.enter(tx) + yield + finally: + gmv.exit(tx) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..86b5301b63e7233fd4061858f081695511517537 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py @@ -0,0 +1,421 @@ +""" +Constant and enum variable tracking in Dynamo. + +This module is fundamental to Dynamo's ability to track and propagate constant +values during compilation, ensuring proper handling of Python literals and +maintaining type safety through the compilation process. +""" + +import enum +import operator +from collections.abc import Sequence +from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import override + +import torch +from torch._dynamo.source import AttrSource, GetItemSource + +from .. import graph_break_hints, variables +from ..exc import raise_observed_exception, unimplemented +from ..utils import ( + cmp_name_to_op_mapping, + common_constant_types, + istype, + np, + raise_args_mismatch, + raise_on_overridden_hash, +) +from .base import ValueMutationNew, VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +class ConstantVariable(VariableTracker): + """ + Variable tracker for Python literals and basic immutable types, with automatic + routing support for collection types (lists, tuples, sets, etc.). + + The create() method intelligently constructs appropriate variable types for + nested collections. + """ + + @overload + @staticmethod + def create(value: bool) -> "ConstantVariable": ... + + @overload + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: ... + + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: + """ + Create a `ConstantVariable` based on the given value, and supports + automatic routing for collection types like `tuple` (in which case we'd + create `ConstantVariable` for the leaf items). + + NOTE: the caller must install the proper guards if needed; most often + the guard will be `CONSTANT_MATCH`. + """ + source = kwargs.get("source") + + # Routing for supported collection literals. + if isinstance(value, set): + items = [ConstantVariable.create(x) for x in value] + return variables.SetVariable(items, **kwargs) # type: ignore[arg-type] + elif isinstance(value, frozenset): + items = [ConstantVariable.create(x) for x in value] + return variables.FrozensetVariable(items, **kwargs) # type: ignore[arg-type] + elif isinstance(value, slice): + slice_args = (value.start, value.stop, value.step) + slice_args_vars = tuple(ConstantVariable.create(arg) for arg in slice_args) + return variables.SliceVariable(slice_args_vars, **kwargs) + elif isinstance(value, (list, tuple)): + items = [] + for i, x in enumerate(value): + item_source = GetItemSource(source, i) if source else None + items.append( + ConstantVariable.create( + x, + source=item_source, + ) + ) + return variables.BaseListVariable.cls_for(type(value))(items, **kwargs) + + return ConstantVariable(value, **kwargs) + + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert ConstantVariable.is_base_literal(value), f""" +Cannot construct `ConstantVariable` for value of type {type(value)}. + +This failure likely due to PyTorch-internal use of `ConstantVariable` on +non-literal python values, please try using `VariableTracker.build` instead. If +you believe it's a necessary and legitimate use case (the value is immutable and +can't easily be represented with another `VariableTracker` class), please add +its type to `common_constant_types`. +""" + if np is not None and isinstance(value, np.number): + self.value = value.item() + else: + self.value = value + + def as_proxy(self) -> Any: + return self.value + + def __repr__(self) -> str: + return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" + + def as_python_constant(self) -> Any: + return self.value + + def is_python_constant(self) -> Literal[True]: + return True + + def is_symnode_like(self) -> bool: + return isinstance(self.value, (int, bool)) + + def is_constant_match(self, *values: Any) -> bool: + return self.value in values + + def is_constant_none(self) -> bool: + return self.value is None + + @property + def items(self) -> list[VariableTracker]: + """ + Need this when adding a BaseListVariable and a ConstantVariable together. + Happens in detectron2. + """ + return self.unpack_var_sequence(tx=None) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + return ConstantVariable.create( + self.value[arg.as_python_constant()], + ) + + @staticmethod + def is_base_literal(obj: object) -> bool: + return type(obj) in common_constant_types + + @staticmethod + def is_literal(obj: object) -> bool: + if type(obj) in (list, tuple, set, frozenset, torch.Size): + return all(ConstantVariable.is_literal(x) for x in obj) # type: ignore[attr-defined] + return ConstantVariable.is_base_literal(obj) + + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] + ) -> list[VariableTracker]: + try: + return [ConstantVariable.create(x) for x in self.as_python_constant()] + except TypeError as e: + raise NotImplementedError from e + + def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if not hasattr(self.value, name): + raise_observed_exception(AttributeError, tx, args=[name]) + member = getattr(self.value, name) + if callable(member): + raise NotImplementedError + return member + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "format" and istype(self.value, str): + return variables.BuiltinVariable(str.format).call_function( + tx, [self, *args], kwargs + ) + elif name == "join" and istype(self.value, str): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + elif name == "__iter__" and istype(self.value, str): + # this could be some generic iterator to avoid the circular import, + # but ListIterator does what we want + from .lists import ListIteratorVariable + + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + + if any(isinstance(x, SymNodeVariable) for x in args): + # Promote to SymNodeVariable for operations involving dynamic shapes. + return variables.SymNodeVariable.create( + tx, self.as_proxy(), self.value + ).call_method(tx, name, args, kwargs) + + try: + const_args = [a.as_python_constant() for a in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + + if isinstance(self.value, str) and name in str.__dict__: + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) + elif isinstance(self.value, (float, int)): + if not (args or kwargs): + try: + return ConstantVariable.create(getattr(self.value, name)()) + except (OverflowError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + if ( + hasattr(operator, name) + and len(args) == 1 + and args[0].is_python_constant() + ): + add_target = const_args[0] + op = getattr(operator, name) + if isinstance( + add_target, (torch.SymBool, torch.SymFloat, torch.SymInt) + ): + # Addition between a non sym and sym makes a sym + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return SymNodeVariable.create(tx, proxy, add_target) + else: + try: + return ConstantVariable.create(op(self.value, add_target)) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + elif isinstance(self.value, bytes) and name == "decode": + method = getattr(self.value, name) + return ConstantVariable.create(method(*const_args, **const_kwargs)) + elif type(self.value) is complex and name in complex.__dict__: + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) + + if name == "__len__" and not (args or kwargs): + # pyrefly: ignore [bad-argument-type] + return ConstantVariable.create(len(self.value)) + elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): + try: + return ConstantVariable.create( + # pyrefly: ignore [no-matching-overload] + round(self.value, args[0].as_python_constant()) + ) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): + assert not kwargs + search = args[0].as_python_constant() + try: + # pyrefly: ignore [unsupported-operation] + result = search in self.value + return ConstantVariable.create(result) + except TypeError as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + return super().call_method(tx, name, args, kwargs) + + def call_tree_map( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is None: + none_is_leaf_var = tree_map_kwargs.get("none_is_leaf") + if none_is_leaf_var is not None: + try: + none_is_leaf = bool(none_is_leaf_var.as_python_constant()) + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + else: + tree_map_module = getattr( + getattr(tree_map_fn, "fn", None), "__module__", "" + ) + # torch.utils._pytree and torch.utils._cxx_pytree treat None as a leaf + # by default, while optree keeps it as an internal node unless + # none_is_leaf=True is provided. + none_is_leaf = not tree_map_module.startswith("optree") + if none_is_leaf: + return map_fn.call_function(tx, [self, *rest], {}) + else: + for other in rest: + if not other.is_constant_none(): + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + return self.clone() + if isinstance(self.value, (int, float, bool, complex, str, bytes, torch.dtype)): + return map_fn.call_function(tx, [self, *rest], {}) + return super().call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + @override + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + # Could be an EnumVariable as well + from .tensor import SymNodeVariable + + if isinstance(other, SymNodeVariable): + return self.as_python_constant() == other.evaluate_expr() + return self.as_python_constant() == other.as_python_constant() + + +class EnumVariable(VariableTracker): + """VariableTracker for enum.Enum and enum.IntEnum instances + + Provides specialized handling for Python enum types, supporting + both standard Enum and IntEnum with proper value tracking and comparison. + """ + + def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def create( + cls, cls_type: Any, value_vt: VariableTracker, options: Any + ) -> "EnumVariable": + if value_vt.is_python_constant(): + for member in list(cls_type): + if member.value == value_vt.as_python_constant(): + return cls(member, **options) + unimplemented( + gb_type="Failed to construct Enum variable", + context=f"value: {value_vt}, allowed enum values: {list(cls_type)}", + explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) " + "or is not an acceptable value for the Enum. " + f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", + hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], + ) + + def as_proxy(self) -> Union[enum.Enum, int]: + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int + return self.value + + def __repr__(self) -> str: + return f"EnumVariable({type(self.value)})" + + def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if not hasattr(self.value, name): + raise NotImplementedError + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + member = getattr(self.value, name) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, member, source=source) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9c08a2d12eb96d3bf94880d17fe9064f9ea53975 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py @@ -0,0 +1,1529 @@ +""" +This file contains a collection of context manager classes used by Dynamo for tracking +and managing various PyTorch runtime states during graph compilation. These context +managers handle different aspects of PyTorch's execution environment, including: + +- Autograd states (grad mode, inference mode) +- CUDA streams and events +- Profiling contexts +- Deterministic algorithms +- Forward/backward AD modes +- SDPA (Scaled Dot Product Attention) kernels +- FSDP (Fully Sharded Data Parallel) states +- AMP (Automatic Mixed Precision) autocast states + +The context managers ensure proper state transitions during graph compilation by +tracking enter/exit points and managing cleanup operations. They help maintain +consistency between eager execution and compiled graph behavior by capturing and +restoring state changes. +""" + +import inspect +import sys +import warnings +from collections.abc import Callable, Sequence, Sized +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch._C +from torch._guards import Guard + +from .. import graph_break_hints, variables +from ..bytecode_transformation import ( + create_call_function, + create_instruction, + create_setup_with, +) +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalStateSource +from ..utils import _get_error_on_graph_break, _set_error_on_graph_break +from .base import VariableTracker +from .functions import ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + WrappedNestedUserFunctionVariable, + WrappedSkipFunctionVariable, + WrappedUserFunctionVariable, + WrappedUserMethodVariable, +) +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ContextWrappingVariable(VariableTracker): + _nonvar_fields = { + "cm_obj", + "target_values", + "initial_values", + "state", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, target_values: Any, initial_values: Optional[Any] = None, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.target_values = target_values + self.initial_values = initial_values + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + if hasattr(self, "_call_func"): + self._call_func(tx, self.target_values) + self.set_cleanup_hook(tx) + return variables.ConstantVariable.create(None) + + def set_cleanup_hook( + self, tx: "InstructionTranslator", fn: Optional[Callable[..., Any]] = None + ) -> None: + if fn is None: + + def fn() -> None: + if hasattr(self, "_call_func"): + self._call_func(tx, self.initial_values) + + self.cleanup_fn: Optional[Callable[..., Any]] = fn + tx.output.add_cleanup_hook(self.cleanup) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + return variables.ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + codegen( + AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: self.reconstruct_type(codegen)) + target_values = self.target_values + if not target_values: + target_values = () + codegen.extend_output([codegen.create_load_const(val) for val in target_values]) + codegen.extend_output(create_call_function(len(target_values), False)) + + def module_name(self) -> str: + raise NotImplementedError("module_name called on base") + + def fn_name(self) -> str: + raise NotImplementedError("fn_name called on base") + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert len(args) == 1 + assert isinstance( + args[0], + ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserMethodVariable, + UserFunctionVariable, + ), + ) + + if isinstance(args[0], NestedUserFunctionVariable): + return WrappedNestedUserFunctionVariable(args[0], self) + elif isinstance(args[0], SkipFunctionVariable): + return WrappedSkipFunctionVariable(args[0], self) + elif isinstance(args[0], UserMethodVariable): + return WrappedUserMethodVariable(args[0], self) + elif isinstance(args[0], UserFunctionVariable): + return WrappedUserFunctionVariable(args[0], self) + else: + raise AssertionError("Unexpected arg type") + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return True + + def cleanup(self) -> None: + if self.cleanup_fn is not None: + self.cleanup_fn() + self.cleanup_fn = None + + def cleanup_assert(self) -> None: + assert self.cleanup_fn, "multiple exits?" + self.cleanup() + + +class GenericContextWrappingVariable(UserDefinedObjectVariable): + # Some methods in ContextWrappingVariable assumes the arguments are + # python constants. Which might not always be the case here. + def __init__(self, cm_obj: AbstractContextManager[Any], **kwargs: Any) -> None: + assert cm_obj is not None + super().__init__( + value=cm_obj, + value_type=cm_obj.__class__, + **kwargs, + ) + self.cm_obj = cm_obj + + def module_name(self) -> str: + return self.cm_obj.__module__ + + def fn_name(self) -> str: + return type(self.cm_obj).__name__ + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + source = None if self.source is None else AttrSource(self.source, "__enter__") + return variables.UserMethodVariable( + self.cm_obj.__enter__.__func__, # type: ignore[attr-defined] + self, + source=source, + ).call_function(tx, [], {}) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + source = None if self.source is None else AttrSource(self.source, "__exit__") + x = variables.UserMethodVariable( + self.cm_obj.__exit__.__func__, # type: ignore[attr-defined] + self, + source=source, + ).call_function(tx, list(args), {}) + tx.active_generic_context_managers.pop() + return x + + def supports_graph_breaks(self) -> bool: + return False + + def exit_on_graph_break(self) -> bool: + return True + + +class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): + def __init__(self, ctx_manager_vt: ContextWrappingVariable, mod: Any) -> None: + self.cm_vt = ctx_manager_vt + self.mod = mod + # We don't call super().__init__() because we're delegating most methods to cm_vt + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + # Custom enter implementation with side effects + + self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() + self.old_buffer_var = self.mod.var_getattr(tx, "_buffers").realize() + tx.output.side_effects.ignore_mutations_on(self.old_parameters_var) + tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) + return self.cm_vt.enter(tx) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # Custom exit implementation with side effects + x = self.cm_vt.exit(tx, *args) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_parameters_var) + return x + + # Forward all other method calls to self.cm_vt + def __getattr__(self, name: str) -> Any: + # This will be called for any attribute not explicitly defined in this class + return getattr(self.cm_vt, name) + + +class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): + """represents torch grad requires grad""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "GradInplaceRequiresGradCtxManagerVariable": + return GradInplaceRequiresGradCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + [enabled] = self.target_values + self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed(enabled) + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.set_inplace_requires_grad_allowed( + self.prev_state + ), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (enabled,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): + """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "TemporarilyPopInterpreterStackCtxManagerVariable": + return TemporarilyPopInterpreterStackCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self.saved = torch._C._functorch.pop_dynamic_layer_stack() + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.pop_dynamic_layer_stack, + (), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.push_dynamic_layer_stack, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.jvp increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if jvp is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a jvp + # call from eager that calls the compiled function, as the jvp levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "JvpIncrementNestingCtxManagerVariable": + var = JvpIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() + self.set_cleanup_hook( + tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting() + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._jvp_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(jvp_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class SetFwdGradEnabledContextManager(ContextWrappingVariable): + """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "SetFwdGradEnabledContextManager": + return SetFwdGradEnabledContextManager( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + [mode] = self.target_values + self.prev_state = torch._C._is_fwd_grad_enabled() + torch._C._set_fwd_grad_enabled(mode) + self.set_cleanup_hook( + tx, + lambda: torch._C._set_fwd_grad_enabled(self.prev_state), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (mode,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class DualLevelContextManager(ContextWrappingVariable): + """Represents torch.autograd.forward_ad.dual_level ctx manager""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) # type: ignore[arg-type] + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs: Any) -> "DualLevelContextManager": + return DualLevelContextManager( + target_values=None, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + self.new_level = torch.autograd.forward_ad.enter_dual_level() + self.set_cleanup_hook( + tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level) + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._enter_dual_level, + (), + {}, + ) + return variables.ConstantVariable.create(self.new_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._exit_dual_level, + (self.new_level,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.grad increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if grad is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a grad + # call from eager that calls the compiled function, as the grad levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "GradIncrementNestingCtxManagerVariable": + var = GradIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + grad_level = torch._C._functorch._grad_increment_nesting() + self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._grad_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(grad_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._grad_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class CatchWarningsCtxManagerVariable(ContextWrappingVariable): + """Delay a call to warnings.catch_warnings""" + + @staticmethod + def create( + tx: "InstructionTranslator", catch_warnings_args: dict[str, VariableTracker] + ) -> "CatchWarningsCtxManagerVariable": + return CatchWarningsCtxManagerVariable( + catch_warnings_args=catch_warnings_args, + target_values=None, + initial_values=None, + ) + + def __init__( + self, + catch_warnings_args: dict[str, VariableTracker], + target_values: Optional[Any] = None, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + assert isinstance(catch_warnings_args, dict), catch_warnings_args + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.catch_warnings_args = catch_warnings_args + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + kwargs = { + k: v.as_python_constant() for k, v in self.catch_warnings_args.items() + } + ctx_val = warnings.catch_warnings(**kwargs) + self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) + return variables.ConstantVariable.create(ctx_val.__enter__()) + + def reconstruct(self, cg: "PyCodegen") -> None: + cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) + cg.foreach(self.catch_warnings_args.values()) + keys = tuple(self.catch_warnings_args.keys()) + cg.extend_output(cg.create_call_function_kw(len(keys), keys, False)) + + +class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch VMap increment/decrement nesting""" + + # A guard is needed as the vmap level is baked into the torch FX graph + # generated. This is fine if vmap is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a vmap + # call from eager that calls the compiled function, as the vmap levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + target_values: Sequence[VariableTracker], + **kwargs: Any, + ) -> "VmapIncrementNestingCtxManagerVariable": + var = VmapIncrementNestingCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + batch_size, randomness = self.target_values + if isinstance(batch_size, variables.SymNodeVariable): + batch_size_value = batch_size.sym_num + else: + batch_size_value = batch_size.as_python_constant() + randomness = randomness.as_python_constant() + vmap_level = torch._C._functorch._vmap_increment_nesting( + batch_size_value, randomness + ) + self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting()) + self.proxy = tx.output.create_proxy( + "call_function", + torch._functorch.predispatch._vmap_increment_nesting, + (batch_size.as_proxy(), randomness), + {}, + ) + return variables.ConstantVariable.create(vmap_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._functorch.predispatch._vmap_decrement_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradModeVariable(ContextWrappingVariable): + """represents torch.{no_grad,enable_grad,set_grad_mode}()""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + target_value: Any, + initialized: bool = False, + **kwargs: Any, + ) -> "GradModeVariable": + var = GradModeVariable( + target_values=[target_value], + initial_values=[torch.is_grad_enabled()], + **kwargs, + ) + if initialized: + var._call_func(tx, var.target_values) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[Sequence[bool]] = None, + initialized: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: + assert len(values) == 1 + value = values[0] + # Coalesce grad mode mutations + if torch.is_grad_enabled() != value: + tx.output.create_node( + "call_function", torch._C._set_grad_enabled, (value,), {} + ) + torch._C._set_grad_enabled(value) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "set_grad_enabled" + + +class InferenceModeVariable(ContextWrappingVariable): + @staticmethod + def create( + tx: "InstructionTranslator", target_value: Any, **kwargs: Any + ) -> "InferenceModeVariable": + var = InferenceModeVariable( + [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs + ) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if initial_values is None: + # This must be called here since function defaults are evaluated at import time + initial_values = torch.is_inference_mode_enabled() + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.autograd.grad_mode._exit_inference_mode, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(None) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + disabled_inference_mode_forcibly = False + if ( + torch._dynamo.config.fake_tensor_disable_inference_mode + and self.target_values[0] + ): + # Do not set the inference mode because we keep it off during + # compilation. Set the grad_enabled to False to reflect the relevant + # part of inference_mode to torch.compile. + disabled_inference_mode_forcibly = True + prior = torch.is_grad_enabled() + torch._C._set_grad_enabled(False) + else: + ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) + + def cleanup_hook() -> None: + if disabled_inference_mode_forcibly: + torch._C._set_grad_enabled(prior) + else: + torch.autograd.grad_mode._exit_inference_mode(ctx) + + self.set_cleanup_hook(tx, cleanup_hook) + self.proxy = tx.output.create_node( + "call_function", + torch.autograd.grad_mode._enter_inference_mode, + (*self.target_values,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "inference_mode" + + +class CUDADeviceVariable(ContextWrappingVariable): + """represents torch.cuda.device""" + + @staticmethod + def create( + tx: "InstructionTranslator", device: Any, **kwargs: Any + ) -> "CUDADeviceVariable": + var = CUDADeviceVariable( + target_values=[torch.cuda._get_device_index(device, optional=True)], + initial_values=None, + **kwargs, + ) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.cuda._maybe_exchange_device, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(False) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + prev_idx = torch.cuda._exchange_device(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) + self.proxy = tx.output.create_node( + "call_function", + torch.cuda._exchange_device, + (*self.target_values,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.cuda" + + def fn_name(self) -> str: + return "device" + + +class TorchFunctionDisableVariable(ContextWrappingVariable): + """represents whether torch function overrides are enabled or not""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "TorchFunctionDisableVariable": + var = TorchFunctionDisableVariable( + target_values=[], + initial_values=[], + **kwargs, + ) + return var + + def __init__( + self, + target_values: Sized, + initial_values: Optional[Sized] = None, + only_subclass: bool = True, + **kwargs: Any, + ) -> None: + assert len(target_values) == 0 + assert initial_values is not None and len(initial_values) == 0 + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + self.only_subclass = only_subclass + self.initial_torch_function_subclass_enabled = ( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + self.initial_torch_function_mode_enabled = ( + tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def set_cleanup_hook( + self, + tx: "InstructionTranslator", + cleanup_fn: Optional[Callable[..., Any]] = None, + ) -> None: + if cleanup_fn is None: + + def cleanup_fn() -> None: + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + self.initial_torch_function_subclass_enabled + ) + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = ( + self.initial_torch_function_subclass_enabled + ) + + self.cleanup_fn = cleanup_fn + tx.output.add_cleanup_hook(self.cleanup) + + def _call_func(self, tx: "InstructionTranslator", values: Sized) -> None: + assert len(values) == 0 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = False + + def module_name(self) -> str: + return "torch._C" + + def fn_name(self) -> str: + if self.only_subclass: + return "DisableTorchFunctionSubclass" + return "DisableTorchFunction" + + +class DeterministicAlgorithmsVariable(ContextWrappingVariable): + """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" + + _guards_singleton = Guard( + GlobalStateSource(), + GuardBuilder.DETERMINISTIC_ALGORITHMS, # type: ignore[arg-type] + ) + + @staticmethod + def create( + tx: "InstructionTranslator", target_value: bool, **kwargs: Any + ) -> "DeterministicAlgorithmsVariable": + var = DeterministicAlgorithmsVariable( + target_values=[target_value], + initial_values=[torch.are_deterministic_algorithms_enabled()], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__( + self, + target_values: Sequence[bool], + initial_values: Optional[Sequence[bool]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: + assert len(values) == 1 + value = values[0] + tx.output.create_node( + "call_function", torch._C._set_deterministic_algorithms, (value,), {} + ) + torch._C._set_deterministic_algorithms(value) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "use_deterministic_algorithms" + + +class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): + """represents torch.autograd.graph.disable_saved_tensors_hook.""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_value: Optional[str], **kwargs: Any + ) -> "DisabledSavedTensorsHooksVariable": + var = DisabledSavedTensorsHooksVariable( + target_values=[target_value], + initial_values=[ + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__( + self, + target_values: Sequence[Optional[str]], + initial_values: Optional[Sequence[Optional[str]]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def _call_func( + self, tx: "InstructionTranslator", values: Sequence[Optional[str]] + ) -> None: + assert len(values) == 1 + value = values[0] + if value is not None: + # Disable `saved_tensors_hooks` with message (`value`) + # OR + # we are exiting this context and restoring the previous message. + tx.output.create_node( + "call_function", + torch._C._autograd._saved_tensors_hooks_disable, + (value,), + {}, + ) + torch._C._autograd._saved_tensors_hooks_disable(value) + else: + # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`. + tx.output.create_node( + "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {} + ) + torch._C._autograd._saved_tensors_hooks_enable() + + def module_name(self) -> str: + return "torch.autograd.graph" + + def fn_name(self) -> str: + return "disable_saved_tensors_hooks" + + +class AutocastModeVariable(ContextWrappingVariable): + @staticmethod + def create( + func: torch.amp.autocast_mode.autocast, + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> "AutocastModeVariable": + assert func in [ + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ] + # device_type : str, + # dtype : Optional[_dtype] = None, + # enabled : bool = True, + # cache_enabled : Optional[bool] = None):cache_enabled + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + target_values = [] + kwargs.clear() + + for key in ["device_type", "dtype", "enabled", "cache_enabled"]: + if key == "device_type" and func in [ + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ]: + arg = "cuda" if func is torch.cuda.amp.autocast else "cpu" + else: + arg = bound_args.arguments[key] + if isinstance(arg, VariableTracker): + target_values.append(arg.as_python_constant()) + else: + target_values.append(arg) + + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) + return var + + def __init__( + self, + target_values: Sequence[Any], + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", torch.amp._exit_autocast, (self.proxy,), {} + ) + return variables.ConstantVariable.create(None) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + ctx = torch.amp._enter_autocast(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) + self.proxy = tx.output.create_node( + "call_function", torch.amp._enter_autocast, (*self.target_values,), {} + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.amp.autocast_mode" + + def fn_name(self) -> str: + return "autocast" + + +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + """ + + def __init__(self, target_values: Optional[Any] = None, **kwargs: Any) -> None: + super().__init__(target_values=target_values, **kwargs) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + none = variables.ConstantVariable.create(None) + return self.target_values if self.target_values else none + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "contextlib" + + def fn_name(self) -> str: + return "nullcontext" + + +class ProfilerContextVariable(ContextWrappingVariable): + """ + This class represents a set of torch profiler context objects, where Dynamo + ignores all the side-effects in the __init__, __enter__ and __exit__ methods + by treating the object mostly as a `contextlib.nullcontext`, except for edge + cases like the `__enter__` method which returns the object itself rather + than `None`, per implementation of the torch objects. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(target_values=None, **kwargs) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return self + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "contextlib" + + def fn_name(self) -> str: + return "nullcontext" + + def reconstruct(self, cg: "PyCodegen") -> None: + unimplemented( + gb_type="torch.profiler object escaped from compiled region", + context=str(self), + explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class PreserveVersionContextVariable(ContextWrappingVariable): + """ + Wraps torch.autograd._unsafe_preserve_version_counter + """ + + @staticmethod + def _create_lambda_from_tensors( + tx: "InstructionTranslator", + tensors: VariableTracker, + ) -> "PreserveVersionContextVariable": + if tensors.is_tensor(): + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in [tensors]] + ) + tensors_tuple = variables.TupleVariable([tensors]) + else: + assert isinstance(tensors, variables.TupleVariable) + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in tensors.items] + ) + tensors_tuple = tensors + return PreserveVersionContextVariable(tensors_tuple, versions) + + @staticmethod + def constructor(tx: "InstructionTranslator") -> VariableTracker: + return variables.LambdaVariable( + lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( + tx, tensors + ) + ) + + def __init__( + self, + tensors: VariableTracker, + prev_versions: VariableTracker, + **kwargs: Any, + ) -> None: + kwargs.setdefault("target_values", None) + super().__init__(**kwargs) + self.tensors = tensors + self.prev_versions = prev_versions + # The context manager accepts Union[Tensor, Tuple[Tensor]] + if self.tensors.is_tensor(): + self.tensors = variables.TupleVariable([self.tensors]) + if self.prev_versions.is_symnode_like(): + self.prev_versions = variables.TupleVariable([self.prev_versions]) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + from ..tensor_version_op import _unsafe_set_version_counter + + return variables.TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [self.tensors, self.prev_versions], {}) + + def reconstruct(self, codegen: "PyCodegen") -> None: + unimplemented( + gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", + context=str(self), + explanation=( + "Dynamo doesn't support compiling a region that returns " + "a torch.autograd._unsafe_preserve_version_counter context manager." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + param_group_var: Any, + target_value: Any, + **kwargs: Any, + ) -> "FSDPParamGroupUseTrainingStateVariable": + var = FSDPParamGroupUseTrainingStateVariable( + param_group_var=param_group_var, + target_values=[target_value], + initial_values=[param_group_var.value._training_state], + **kwargs, + ) + return var + + def __init__( + self, + param_group_var: Any, + target_values: Sequence[Any], + initial_values: Optional[Sequence[Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.param_group_var = param_group_var + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # type: ignore[arg-type] + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # undo eager initialization + self._call_func(tx, self.initial_values) # type: ignore[arg-type] + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[Any]) -> None: + assert len(values) == 1 + value = values[0] + if self.param_group_var.value._training_state != value: + self.param_group_var.call_method( + tx, + "__setattr__", + ( + variables.ConstantVariable.create("_training_state"), + variables.EnumVariable(value), + ), + {}, + ) + self.param_group_var.value._training_state = value + + def module_name(self) -> str: + return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" + + def fn_name(self) -> str: + return "use_training_state" + + +class SDPAKernelVariable(ContextWrappingVariable): + """represents torch.nn.attention.sdpa_kernel""" + + @staticmethod + def create( + tx: "InstructionTranslator", + backends: Any, + set_priority: bool = False, + **kwargs: Any, + ) -> "SDPAKernelVariable": + if isinstance(backends, torch.nn.attention.SDPBackend): + backends = [backends] + var = SDPAKernelVariable( + target_values=backends, + initial_values=None, + set_priority=set_priority, + **kwargs, + ) + return var + + def __init__( + self, + target_values: list[torch.nn.attention.SDPBackend], + initial_values: Any = None, + set_priority: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.set_priority = set_priority + + @staticmethod + def _backends_to_nodes( + tx: "InstructionTranslator", + backends: list[Any], + ) -> list[Any]: + # convert to/from string in order to bake the backend into FX graph + nodes = [ + tx.output.create_node( + "call_function", + torch.nn.attention._backend_from_string, + (backend.name,), + {}, + ) + for backend in backends + ] + return nodes + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( + with_priority=self.set_priority + ) + self.set_cleanup_hook( + tx, + lambda: torch.nn.attention._sdpa_kernel( + self.prev_backends, set_priority=self.set_priority + ), + ) + torch.nn.attention._sdpa_kernel( + self.target_values, set_priority=self.set_priority + ) + arg = self._backends_to_nodes(tx, self.target_values) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + arg = self._backends_to_nodes(tx, self.prev_backends) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.nn.attention" + + # use a private version of sdpa_kernel that accepts variadic arguments + # since dynamo reconstructs the contents of target_values one-by-one + def fn_name(self) -> str: + return "_sdpa_kernel_variadic" + + +class FxTracebackAnnotateVariable(ContextWrappingVariable): + """ + fx.traceback.annotate is a context manager that allows users to annotate the + fx graph nodes with custom metadata. In the context of Dynamo, we don't have + to trace the body of the context manager. Instead we want to directly run + the body of the context manager, so the Dynamo created Fx graphs have the + right custom metadata. This variable tracker just runs __enter__ and + __exit__ method (instead of tracing). + """ + + def __init__( + self, target_values: Any, initial_values: Any = None, **kwargs: Any + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # Run the annotation ctx manager in eager. Also ensure that + # preserve_node_meta context manager is setup. This is important to pass + # on the metadata to the create_proxy nodes. + stack = ExitStack() + stack.enter_context(torch.fx.traceback.annotate(self.target_values)) + stack.enter_context(torch.fx.traceback.preserve_node_meta()) + self.set_cleanup_hook(tx, lambda: stack.close()) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.fx.traceback" + + def fn_name(self) -> str: + return "annotate" + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + unimplemented( + gb_type="torch.fx.traceback.annotate escaped from compiled region", + context=str(self), + explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class DynamoConfigPatchVariable(ContextWrappingVariable): + """represents torch._dynamo.patch_dynamo_config""" + + # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness + # (though it may affect tracing behavior) + def __init__(self, target_values: dict[str, Any], **kwargs: Any) -> None: + target_values_tuple = tuple(target_values.items()) + super().__init__( + target_values=(target_values_tuple,), initial_values=None, **kwargs + ) + initial_values_dict = {} + for key, _ in target_values_tuple: + initial_values_dict[key] = torch._dynamo.config.__getattr__(key) # type: ignore[attr-defined] + self.initial_values = (tuple(initial_values_dict.items()),) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: + assert len(values) == 1 + value = values[0] + # manually patch dynamo config + for key, val in value: + torch._dynamo.config.__setattr__(key, val) # type: ignore[attr-defined] + # No need to keep track of global side effects because + # dynamo will properly restore this context manager for + # unsupported instructions and continuation functions. + # Dynamo config also should not affect the semantics of the compiled graph. + + def module_name(self) -> str: + return "torch._dynamo" + + def fn_name(self) -> str: + return "patch_dynamo_config" + + +class ErrorOnGraphBreakVariable(ContextWrappingVariable): + """represents torch._dynamo.error_on_graph_break""" + + def __init__(self, error_on_graph_break: bool, **kwargs: Any) -> None: + super().__init__( + target_values=(error_on_graph_break,), + initial_values=(_get_error_on_graph_break(),), + **kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: + assert len(values) == 1 + _set_error_on_graph_break(values[0]) + + def module_name(self) -> str: + return "torch._dynamo" + + def fn_name(self) -> str: + return "error_on_graph_break" + + +class WithEnterFunctionVariable(VariableTracker): + def __init__( + self, + ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.ctx = ctx + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert not args + assert not kwargs + # NOTE: we assume that the instruction immediately after the current CALL instruction + # is the first instruction of the block. + # pyrefly: ignore [bad-argument-type] + return tx.enter_ctx(self.ctx, tx.current_instruction) + + def reconstruct(self, codegen: "PyCodegen") -> None: + try: + type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" + except NotImplementedError: + type_str = str(type(self.ctx)) + unimplemented( + gb_type="Attempted to reconstruct context manager's __enter__ method", + context=str(self.ctx), + explanation=f"Attempted to reconstruct context manager {type_str} while tracing `with ...:`", + hints=[ + "It is likely there is a graph break while tracing `with ctx:` " + "but outside the actual `ctx.__enter__()` method. " + "`torch.compile` does not expect this to happen.", + *graph_break_hints.DIFFICULT, + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +class WithExitFunctionVariable(VariableTracker): + _nonvar_fields = { + "target", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], + target: Any, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ) + self.ctx = ctx + self.target = target + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert not kwargs + return self.ctx.exit(tx, *args) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Note here we reconstruct the context manager rather than the + # exit function. The handler generated by BlockStackEntry + # will re-enter the context in the resume function. + self.ctx.reconstruct_type(codegen) # type: ignore[union-attr] + if codegen.tx.output.partial_convert: + if sys.version_info >= (3, 11): + codegen.append_output(create_instruction("PUSH_NULL")) + if sys.version_info < (3, 13): + codegen.append_output(create_instruction("SWAP", arg=2)) + # We rely on classes subtyping `GenericContextWrappingVariable` + # to implement these fns and have these attributes + codegen.extend_output( + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr] + ) + codegen.extend_output( + create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr] + ) + codegen.append_output(create_setup_with(self.target)) + codegen.append_output(create_instruction("POP_TOP")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..3a07bc1ac03cea5d41890904ce988f5608c96a82 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py @@ -0,0 +1,1555 @@ +""" +Dictionary-related variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for different types of dictionary-like objects: +- Regular Python dictionaries (dict) +- Ordered dictionaries (collections.OrderedDict) +- Default dictionaries (collections.defaultdict) +- Dictionary views (keys and values) +- Sets and frozensets (implemented internally using dictionaries) + +These classes are responsible for tracking dictionary operations during graph compilation, +maintaining proper guards for dictionary mutations and key existence checks. They handle +dictionary creation, modification, key/value access, and view operations while ensuring +correct behavior in the compiled code through appropriate guard installation. + +The implementation uses a special _HashableTracker wrapper to handle dictionary keys +while preserving proper aliasing semantics. Sets are implemented as dictionaries with +None values for efficiency and code reuse. +""" + +import collections +import functools +import operator +import types +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING, Union + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import is_constant_source, is_from_local_source +from ..utils import ( + cmp_name_to_op_mapping, + dict_items, + dict_keys, + dict_values, + istype, + raise_args_mismatch, + specialize_symnode, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .lists import ListIteratorVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +# [Adding a new supported class within the keys of ConstDictVariable] +# - Implement is_python_hashable() method in the VariableTracker subclass +# - Implement get_python_hash() and is_python_equal() methods for hashable types + + +def was_instancecheck_override(obj: Any) -> bool: + return type(obj).__dict__.get("__instancecheck__", False) + + +def raise_unhashable( + arg: VariableTracker, tx: Optional["InstructionTranslator"] = None +) -> None: + if tx is None: + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + try: + arg_type = arg.python_type() + except Exception: + arg_type = type(arg) + + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable( + f"unhashable type: {arg_type!r} and variable tracker = {type(arg.realize())}" + ) + ], + ) + + +def is_hashable(x: VariableTracker) -> bool: + # NB - performing isinstance check on a LazVT realizes the VT, accidentally + # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at + # the underlying value without realizing the VT. Consider updating the + # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. + if ( + isinstance(x, variables.LazyVariableTracker) + and not x.is_realized() + and x.is_hashable() + ): + return True + return x.is_python_hashable() + + +class ConstDictVariable(VariableTracker): + CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS + + _nonvar_fields = { + "user_cls", + *VariableTracker._nonvar_fields, + } + + class _HashableTracker: + """ + Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable + This should not be seen or touched by anything outside of ConstDictVariable and its children + Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing + """ + + def __init__(self, vt: VariableTracker) -> None: + # We specialize SymNodes + vt = specialize_symnode(vt) + + # If Dynamo does not know the hashability of the vt, it will raise unsupported here + if not is_hashable(vt): + raise_unhashable(vt) + self.vt = vt + + def __hash__(self) -> int: + """ + Computes the hash value for the wrapped VariableTracker. + + For unrealized LazyVariableTrackers, uses the hash of the original value + to avoid realizing the tracker and inserting unnecessary guards. + For all other cases, delegates to the VariableTracker's get_python_hash method. + + Returns: + The hash value of the underlying variable tracker + """ + if ( + isinstance(self.vt, variables.LazyVariableTracker) + and not self.vt.is_realized() + and self.vt.is_hashable() + ): + return hash(self.vt.original_value()) + return self.vt.get_python_hash() + + def __eq__(self, other) -> bool: + """ + Checks equality between two _HashableTracker instances. + + Delegates to the VariableTracker's is_python_equal method to compare + the underlying variable trackers for Python-level equality. + + Args: + other: Another _HashableTracker instance to compare with + + Returns: + True if the underlying variable trackers are Python-equal, False otherwise + """ + if self.vt is other.vt: + return True + return self.vt.is_python_equal(other.vt) + + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type = dict, + **kwargs: Any, + ) -> None: + # .clone() pass these arguments in kwargs but they're recreated a few + # lines below + if "original_items" in kwargs: + kwargs.pop("original_items") + if "should_reconstruct_all" in kwargs: + kwargs.pop("should_reconstruct_all") + + super().__init__(**kwargs) + + Hashable = ConstDictVariable._HashableTracker + + # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers + assert all( + isinstance(x, (VariableTracker, Hashable)) + and isinstance(v, VariableTracker) + for x, v in items.items() + ) + + def make_hashable( + key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], + ) -> "ConstDictVariable._HashableTracker": + return key if isinstance(key, Hashable) else Hashable(key) + + dict_cls = self._get_dict_cls_from_user_cls(user_cls) + self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) + # need to reconstruct everything if the dictionary is an intermediate value + # or if a pop/delitem was executed + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) + self.original_items = items.copy() + self.user_cls = user_cls + + def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: + accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) + + # avoid executing user code if user_cls is a dict subclass + if user_cls in accepted_dict_types: + dict_cls = user_cls + else: + # + dict_cls = next( + base for base in user_cls.__mro__ if base in accepted_dict_types + ) + assert dict_cls in accepted_dict_types, dict_cls + + # Use a dict instead as the call "defaultdict({make_hashable(x): v ..})" + # would fail as defaultdict expects a callable as first argument + if dict_cls is collections.defaultdict: + dict_cls = dict + return dict_cls + + def as_proxy(self) -> dict[Any, Any]: + return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} + + def debug_repr(self) -> str: + return ( + "{" + + ", ".join( + f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items() + ) + + "}" + ) + + def as_python_constant(self) -> dict[Any, Any]: + return { + k.vt.as_python_constant(): v.as_python_constant() + for k, v in self.items.items() + } + + def keys_as_python_constant(self) -> dict[Any, VariableTracker]: + self.install_dict_keys_match_guard() + return {k.vt.as_python_constant(): v for k, v in self.items.items()} + + def python_type(self) -> type: + return self.user_cls + + def __contains__(self, vt: VariableTracker) -> bool: + assert isinstance(vt, VariableTracker) + Hashable = ConstDictVariable._HashableTracker + return ( + vt.is_python_hashable() + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) + + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + other_dicts: list[ConstDictVariable] = [] + for candidate in rest: + candidate = candidate.realize() + if not isinstance(candidate, ConstDictVariable) or len( + candidate.items + ) != len(self.items): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_dicts.append(candidate) + + new_items_hashed = type(self.items)() + for key_tracker, value in self.items.items(): + sibling_leaves: list[VariableTracker] = [] + for candidate in other_dicts: + try: + sibling_leaves.append(candidate.items[key_tracker]) + except KeyError: + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + new_items_hashed[key_tracker] = value.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + + updated_original_items = { + key_tracker.vt: new_items_hashed[key_tracker] + for key_tracker in new_items_hashed + } + + return self.clone( + items=new_items_hashed, + original_items=updated_original_items, + should_reconstruct_all=True, + source=None, + mutation_type=ValueMutationNew(), + ) + + def len(self) -> int: + return sum( + not isinstance(x, variables.DeletedVariable) for x in self.items.values() + ) + + def has_new_items(self) -> bool: + return self.should_reconstruct_all or any( + self.is_new_item(self.original_items.get(key.vt), value) + for key, value in self.items.items() + ) + + def is_new_item( + self, value: Optional[VariableTracker], other: VariableTracker + ) -> bool: + # compare the id of the realized values if both values are not lazy VTs + if value and value.is_realized() and other.is_realized(): + return id(value.realize()) != id(other.realize()) + return id(value) != id(other) + + def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: + # Build a dictionary that contains the keys and values. + num_args = 0 + for key, value in self.items.items(): + # We can safely call realize() here as it won't introduce any new guards + item = self.original_items.get(key.vt) + if self.is_new_item(item, value) or self.should_reconstruct_all: + codegen(key.vt) + codegen(value) + num_args += 1 + codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) + + def reconstruct(self, codegen: "PyCodegen") -> None: + if self.user_cls is collections.OrderedDict: + # emit `OrderedDict(constructed_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("OrderedDict"), + ] + ) + ) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(1, False)) + else: + self.reconstruct_kvs_into_new_dict(codegen) + + def getitem_const_raise_exception_if_absent( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + try: + error_message = ( + f"Dict key lookup failed for {str(arg)}. " + f"Debug representation of the key is {arg.debug_repr()!r}" + ) + except Exception: + error_message = ConstantVariable.create( + f"Dict key lookup failed for {str(arg)}" + ) + raise_observed_exception(KeyError, tx, args=[error_message]) + return self.items[key] + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] + unimplemented( + gb_type="key not found in dict", + context=f"Key {arg.value}", # type: ignore[attr-defined] + explanation=msg, + hints=[ + "Check if the key exists in the dictionary before accessing it.", + *graph_break_hints.USER_ERROR, + ], + ) + return self.items[key] + + def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + return None + return self.items[key] + + def realize_key_vt(self, arg: VariableTracker) -> None: + # Realize the LazyVT on a particular index + assert arg in self + key = ConstDictVariable._HashableTracker(arg) + index = tuple(self.items.keys()).index(key) + original_key_vt = tuple(self.original_items.keys())[index] + if isinstance(original_key_vt, variables.LazyVariableTracker): + original_key_vt.realize() + + def install_dict_keys_match_guard(self) -> None: + if self.source: + install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + # Key guarding - These are the cases to consider + # 1) The dict has been mutated. In this case, we would have already + # inserted a DICT_KEYS_MATCH guard, so we can skip. + # + # 2) args[0].source is None. This happens for const keys. Here, we + # have to insert the DICT_CONTAINS guard. + # + # 3) args[0].source is not None. This can happen for non-const VTs. + # 3a) contains=True. In this case, we can access the lazyVT from + # original_items and selectively realize it. + # 3b) contains=False. There is no easy way to selectively apply this + # DICT_NOT_CONTAINS guard because our guard are represented via trees. + # Be conservative and add DICT_KEYS_MATCH guard. + + if not self.source: + return + + if tx.output.side_effects.is_modified(self): + return + + contains = args[0] in self + if args[0].source is None and args[0].is_python_constant(): + install_guard( + self.make_guard( + functools.partial( + type(self).CONTAINS_GUARD, + key=args[0].as_python_constant(), + invert=not contains, + ) + ) + ) + elif args[0].source: + if contains: + self.realize_key_vt(args[0]) + else: + self.install_dict_keys_match_guard() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NB - Both key and value are LazyVariableTrackers in the beginning. So, + # we have to insert guards when a dict method is accessed. For this to + # be simple, we are conservative and overguard. We skip guard only for + # get/__getitem__ because the key guard will be inserted by the + # corresponding value VT. For __contains__, we add a DICT_CONTAINS + # guard. But for all the other methods, we insert the DICT_KEYS_MATCH + # guard to be conservative. + from . import BuiltinVariable, ConstantVariable + + Hashable = ConstDictVariable._HashableTracker + + if name == "__init__": + temp_dict_vt = variables.BuiltinVariable(dict).call_dict( + tx, *args, **kwargs + ) + tx.output.side_effects.mutation(self) + self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] + return ConstantVariable.create(None) + elif name == "__getitem__": + # Key guarding - Nothing to do. LazyVT for value will take care. + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + return self.getitem_const_raise_exception_if_absent(tx, args[0]) + elif name == "items": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + return DictItemsVariable(self) + elif name == "keys": + if len(args): + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + return DictKeysVariable(self) + elif name == "values": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + if args or kwargs: + raise_observed_exception(TypeError, tx) + return DictValuesVariable(self) + elif name == "copy": + self.install_dict_keys_match_guard() + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None + ) + elif name == "__len__": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + return ConstantVariable.create(len(self.items)) + elif name == "__setitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_keys_match_guard() + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = args[1] + return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if arg_hashable: + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + elif name == "get": + if len(args) not in (1, 2): + raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + if args[0] not in self: + self.install_dict_contains_guard(tx, args) + if len(args) == 1: + # if default is not given, return None + return ConstantVariable.create(None) + return args[1] + # Key guarding - Nothing to do. + return self.getitem_const(tx, args[0]) + elif name == "pop" and self.is_mutable(): + if len(args) not in (1, 2): + raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + if args[0] not in self: + # missing item, return the default value. Install no DICT_CONTAINS guard. + self.install_dict_contains_guard(tx, args) + if len(args) == 1: + # if default is not given, raise KeyError + raise_observed_exception(KeyError, tx) + return args[1] + + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + return self.items.pop(Hashable(args[0])) + elif name == "popitem" and self.is_mutable(): + if ( + issubclass(self.user_cls, dict) + and not issubclass(self.user_cls, collections.OrderedDict) + and len(args) + ): + raise_args_mismatch(tx, name) + + if not self.items: + msg = ConstantVariable.create("popitem(): dictionary is empty") + raise_observed_exception(KeyError, tx, args=[msg]) + + if self.user_cls is collections.OrderedDict and ( + len(args) == 1 or "last" in kwargs + ): + if len(args) == 1 and args[0].is_python_constant(): + last = args[0].as_python_constant() + elif (v := kwargs.get("last")) and v.is_python_constant(): + last = v.as_python_constant() + else: + raise_args_mismatch(tx, name) + k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] + else: + k, v = self.items.popitem() + + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + + return variables.TupleVariable([k.vt, v]) + elif name == "clear": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif name == "update" and self.is_mutable(): + # In general, this call looks like `a.update(b, x=1, y=2, ...)`. + # Either `b` or the kwargs is omittable, but not both. + self.install_dict_keys_match_guard() + has_arg = len(args) == 1 + has_kwargs = len(kwargs) > 0 + if has_arg or has_kwargs: + tx.output.side_effects.mutation(self) + if has_arg: + if isinstance(args[0], ConstDictVariable): + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() + dict_vt: ConstDictVariable = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + self.items.update(dict_vt.items) # type: ignore[attr-defined] + if has_kwargs: + # Handle kwargs + kwargs_hashable = { + Hashable(ConstantVariable.create(k)): v + for k, v in kwargs.items() + } + self.items.update(kwargs_hashable) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + elif name == "__contains__": + if not len(args): + raise_args_mismatch( + tx, + name, + "more than 1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_contains_guard(tx, args) + contains = args[0] in self + return ConstantVariable.create(contains) + elif name == "setdefault" and self.is_mutable(): + if len(args) not in (1, 2): + raise_args_mismatch( + tx, + name, + "1 or 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_keys_match_guard() + if kwargs or len(args) > 2: + raise_args_mismatch( + tx, + name, + "at most 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + value = self.maybe_getitem_const(args[0]) + if value is not None: + return value + else: + if len(args) == 1: + x = ConstantVariable.create(None) + else: + x = args[1] + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = x + return x + elif name == "move_to_end": + self.install_dict_keys_match_guard() + tx.output.side_effects.mutation(self) + if args[0] not in self: + raise_observed_exception(KeyError, tx) + + last = True + if len(args) == 2 and args[1].is_python_constant(): + last = args[1].as_python_constant() + + if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): + last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] + + key = Hashable(args[0]) + self.items.move_to_end(key, last=last) + return ConstantVariable.create(None) + elif name == "__eq__" and istype( + self, ConstDictVariable + ): # don't let Set use this function + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + return variables.UserFunctionVariable(polyfills.dict___eq__).call_function( + tx, [self, args[0]], {} + ) + elif name == "__ne__": + return ConstantVariable.create( + not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] + ) + elif name == "__or__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + other = args[0] + + # Method resolution for binops works as follow (using __or__ as example): + # (1) dict.__or__(dict) => dict + # (2) dict.__or__(subclass): return NotImplemented + # (3) Check if subclass implements __ror__ => forward the call + # to subclass.__ror__(dict) + + # Let's not forward the call to __ror__ yet because __ror__ can be + # implemented in C (i.e. OrderedDict subclass) which Dynamo cannot + # trace + # if istype(other, variables.UserDefinedDictVariable): + # if other.call_obj_hasattr(tx, "__ror__").value: + # return other.call_method(tx, "__ror__", [self], kwargs) + + # The three dict types Dynamo can handle are dict, OrderedDict and + # defaultdict. + + # TODO(guilhermeleobas): this check should be on builtin.py::call_or_ + if not istype( + other, (ConstDictVariable, variables.UserDefinedDictVariable) + ): + err_msg = ( + f"unsupported operand type(s) for |: '{self.python_type().__name__}'" + f"and '{other.python_type().__name__}'" + ) + raise_observed_exception(TypeError, tx, args=[err_msg]) + + # OrderedDict overloads __ror__ + ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] + user_cls = ( + collections.OrderedDict + if any(issubclass(t, collections.OrderedDict) for t in ts) + else dict + ) + + self.install_dict_keys_match_guard() + new_dict_vt = self.clone( + items=self.items.copy(), + mutation_type=ValueMutationNew(), + source=None, + user_cls=user_cls, + ) + + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] + new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] + return new_dict_vt + elif name == "__ior__": + self.call_method(tx, "update", args, kwargs) + return self + elif name == "__iter__": + if self.source and not is_constant_source(self.source): + tx.output.guard_on_key_order.add(self.source) + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + else: + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + self.install_dict_keys_match_guard() + return [x.vt for x in self.items] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + # dict not allow setting arbitrary attributes. OrderedDict and + # defaultdict allow arbitrary setattr, but not deletion of default attrs + if any( + self.user_cls is t + for t in (dict, collections.OrderedDict, collections.defaultdict) + ): + if hasattr(self.user_cls, name): + return ConstantVariable.create(True) + if self.user_cls is dict: + return ConstantVariable.create(False) + + msg = f"hasattr on {self.user_cls} is not supported" + unimplemented( + gb_type="unsupported hasattr operation", + context=f"Class {self.user_cls}", + explanation=msg, + hints=[ + "Consider using a regular dictionary instead", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def clone(self, **kwargs: Any) -> VariableTracker: + self.install_dict_keys_match_guard() + return super().clone(**kwargs) + + def is_python_hashable(self): + """ + Dictionaries are mutable and therefore not hashable in Python. + """ + return False + + +class MappingProxyVariable(VariableTracker): + # proxies to the original dict_vt + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + def python_type(self) -> type: + return types.MappingProxyType + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return self.dv_dict.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # load types.MappingProxyType + if self.source: + msg = ( + f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed " + "because the connection to the original dict will be lost." + ) + unimplemented( + gb_type="mapping proxy cannot be reconstructed", + context=f"Source: {self.source}", + explanation=msg, + hints=[ + "Use a mapping proxy constructed in the same `torch.compile` region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(types), + codegen.create_load_attr("MappingProxyType"), + ] + ) + ) + codegen(self.dv_dict) + codegen.extend_output(create_call_function(1, False)) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.source and tx.output.side_effects.has_existing_dict_mutation(): + msg = ( + "A dict has been modified while we have an existing mappingproxy object. " + "A mapping proxy object, as the name suggest, proxies a mapping " + "object (usually a dict). If the original dict object mutates, it " + "is reflected in the proxy object as well. For an existing proxy " + "object, we do not know the original dict it points to. Therefore, " + "for correctness we graph break when there is dict mutation and we " + "are trying to access a proxy object." + ) + + unimplemented( + gb_type="mapping proxy affected by dictionary mutation", + context=f"Source: {self.source}, Dict mutation detected", + explanation=msg, + hints=[ + "Avoid modifying dictionaries that might be referenced by mapping proxy objects", + "Or avoid using the mapping proxy objects after modifying its underlying dictionary", + ], + ) + return self.dv_dict.call_method(tx, name, args, kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is types.MappingProxyType: + return ConstantVariable.create(name in types.MappingProxyType.__dict__) + return super().call_obj_hasattr(tx, name) + + +class NNModuleHooksDictVariable(ConstDictVariable): + # Special class to avoid adding any guards on the nn module hook ids. + def install_dict_keys_match_guard(self) -> None: + pass + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + pass + + +class DefaultDictVariable(ConstDictVariable): + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type, + default_factory: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + super().__init__(items, user_cls, **kwargs) + assert user_cls is collections.defaultdict + if default_factory is None: + default_factory = ConstantVariable.create(None) + self.default_factory = default_factory + + def is_python_constant(self) -> bool: + # Return false for unsupported defaults. This ensures that a bad handler + # path is not taken in BuiltinVariable for getitem. + if self.default_factory not in [list, tuple, dict] and not self.items: + return False + return super().is_python_constant() + + def debug_repr(self) -> str: + assert self.default_factory is not None + return ( + f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" + ) + + @staticmethod + def is_supported_arg(arg: VariableTracker) -> bool: + if isinstance(arg, variables.BuiltinVariable): + return arg.fn in (list, tuple, dict, set) + else: + return isinstance( + arg, + ( + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + ), + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + if args[0] in self: + return self.getitem_const(tx, args[0]) + else: + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[args[0]]) + else: + default_var = self.default_factory.call_function(tx, [], {}) + super().call_method( + tx, "__setitem__", [args[0], default_var], kwargs + ) + return default_var + else: + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # emit `defaultdict(default_factory, new_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("defaultdict"), + ] + ) + ) + codegen(self.default_factory) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(2, False)) + + +# TODO: Implementing this via inheritance rather than composition is a +# footgun, because self method calls in dict will route back to the set +# implementation, which is almost assuredly wrong +class SetVariable(ConstDictVariable): + """We model a sets as dictionary with None values""" + + CONTAINS_GUARD = GuardBuilder.SET_CONTAINS + + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + # pyrefly: ignore[bad-assignment] + items = dict.fromkeys(items, SetVariable._default_value()) + # pyrefly: ignore[bad-argument-type] + super().__init__(items, **kwargs) + + def debug_repr(self) -> str: + if not self.items: + return "set()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" + + @property + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: + return set(self.items.keys()) + + @staticmethod + def _default_value() -> VariableTracker: + # Variable to fill in he keys of the dictionary + return ConstantVariable.create(None) + + def as_proxy(self) -> Any: + return {k.vt.as_proxy() for k in self.set_items} + + def python_type(self) -> type: + return set + + def as_python_constant(self) -> Any: + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) + + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + res = fn( + *[x.as_python_constant() for x in [self, *args]], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + except Exception as exc: + raise_observed_exception( + type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) + ) + # pyrefly: ignore[unbound-name] + return VariableTracker.build(tx, res) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # We forward the calls to the dictionary model + from ..utils import check_constant_args + + if ( + name + in ( + "isdisjoint", + "union", + "intersection", + "difference", + "symmetric_difference", + ) + and check_constant_args(args, kwargs) + and self.python_type() is set + ): + py_type = self.python_type() + return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) + + if name == "__init__": + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) + tx.output.side_effects.mutation(self) + self.items.clear() + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] + return ConstantVariable.create(None) + elif name == "add": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + name = "__setitem__" + args = [args[0], SetVariable._default_value()] + elif name == "pop": + if kwargs or args: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # Choose an item at random and pop it via the Dict.pop method + try: + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] + except KeyError as e: + raise_observed_exception( + KeyError, tx, args=list(map(ConstantVariable.create, e.args)) + ) + # pyrefly: ignore[unbound-name] + super().call_method(tx, name, [result], kwargs) + # pyrefly: ignore[unbound-name] + return result + elif name == "isdisjoint": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_isdisjoint + ).call_function(tx, [self, args[0]], {}) + elif name == "intersection": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_intersection + ).call_function(tx, [self, *args], {}) + elif name == "intersection_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_intersection_update + ).call_function(tx, [self, *args], {}) + elif name == "union": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable(polyfills.set_union).call_function( + tx, [self, *args], {} + ) + elif name == "difference": + if kwargs: + raise_args_mismatch( + tx, name, f"Expect: 0 kwargs, Actual: {len(kwargs)} kwargs" + ) + return variables.UserFunctionVariable( + polyfills.set_difference + ).call_function(tx, [self, *args], {}) + elif name == "difference_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference_update": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "update" and self.is_mutable(): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable(polyfills.set_update).call_function( + tx, [self, *args], {} + ) + elif name == "remove": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] not in self: + raise_observed_exception(KeyError, tx, args=args) + return super().call_method(tx, "pop", args, kwargs) + elif name == "discard": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] in self: + return super().call_method(tx, "pop", args, kwargs) + else: + return ConstantVariable.create(value=None) + elif name in ("issubset", "issuperset"): + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + op = { + "issubset": operator.le, + "issuperset": operator.ge, + } + other = args[0].realize() + if not istype(other, SetVariable): + other = variables.BuiltinVariable(set).call_function(tx, [other], {}) + return variables.BuiltinVariable(op.get(name)).call_function( + tx, [self, other], {} + ) + elif name in ("__and__", "__or__", "__xor__", "__sub__"): + m = { + "__and__": "intersection", + "__or__": "union", + "__xor__": "symmetric_difference", + "__sub__": "difference", + }.get(name) + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + assert m is not None + return self.call_method(tx, m, args, kwargs) + elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + m = { + "__iand__": "intersection_update", + "__ior__": "update", + "__ixor__": "symmetric_difference_update", + "__isub__": "difference_update", + }.get(name) + assert m is not None + self.call_method(tx, m, args, kwargs) + return self + elif name == "__eq__": + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(False) + r = self.call_method(tx, "symmetric_difference", args, kwargs) + return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] + elif name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] + ) + return super().call_method(tx, name, args, kwargs) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + raise RuntimeError("Illegal to getitem on a set") + + def install_dict_keys_match_guard(self) -> None: + # Already EQUALS_MATCH guarded + pass + + +class FrozensetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "frozenset()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" + + @property + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: + return self.items.keys() + + def python_type(self) -> type: + return frozenset + + def as_python_constant(self) -> Any: + return frozenset({k.vt.as_python_constant() for k in self.set_items}) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach([x.vt for x in self.set_items]) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + elif name == "__init__": + # frozenset is immutable. Calling __init__ again shouldn't have any effect + # In[1]: s = frozenset([1, 2]) + # + # In[2]: s.__init__([3, 4]) + # + # In[3]: s + # frozenset({1, 2}) + return ConstantVariable.create(None) + elif name in ( + "copy", + "difference", + "intersection", + "symmetric_difference", + ): + r = super().call_method(tx, name, args, kwargs) + return FrozensetVariable(r.items) # type: ignore[attr-defined] + return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Frozensets are immutable and hashable in Python. + """ + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class DictKeySetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "dict_keys([])" + else: + return ( + "dict_keys([" + ",".join(k.vt.debug_repr() for k in self.items) + "])" + ) + + def install_dict_keys_match_guard(self) -> None: + # Already EQUALS_MATCH guarded + pass + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + # Already EQUALS_MATCH guarded + pass + + @property + def set_items(self) -> Any: + return self.items + + def python_type(self) -> type: + return dict_keys + + def as_python_constant(self) -> Any: + return dict.fromkeys( + {k.vt.as_python_constant() for k in self.set_items}, None + ).keys() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a dict_keys") + return super().call_method(tx, name, args, kwargs) + + +class DictViewVariable(VariableTracker): + """ + Models _PyDictViewObject + + This is an "abstract" class. Subclasses will override kv and the items method + """ + + kv: Optional[str] = None + + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert self.kv in ("keys", "values", "items") + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + @property + def view_items(self) -> Any: + assert self.kv is not None + return getattr(self.dv_dict.items, self.kv)() + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + # Implement in the subclasses + raise NotImplementedError + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return self.view_items_vt + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.kv is not None + codegen(self.dv_dict) + codegen.load_method(self.kv) + codegen.call_method(0) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + assert self.kv is not None + if name in self.python_type().__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__len__": + return self.dv_dict.call_method(tx, name, args, kwargs) + elif name == "__iter__": + return ListIteratorVariable( + self.view_items_vt, mutation_type=ValueMutationNew() + ) + return super().call_method(tx, name, args, kwargs) + + +class DictKeysVariable(DictViewVariable): + kv = "keys" + + @property + def set_items(self) -> set[VariableTracker]: + return set(self.view_items) + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + return [x.vt for x in self.view_items] + + def python_type(self) -> type: + return dict_keys + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__contains__": + return self.dv_dict.call_method(tx, name, args, kwargs) + elif name in ( + "__and__", + "__iand__", + "__or__", + "__ior__", + "__sub__", + "__isub__", + "__xor__", + "__ixor__", + ): + # These methods always returns a set + m = getattr(self.set_items, name) + r = m(args[0].set_items) # type: ignore[attr-defined] + return SetVariable(r) + if name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, DictKeysVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] + ) + return super().call_method(tx, name, args, kwargs) + + +class DictValuesVariable(DictViewVariable): + # DictValuesVariable is an iterable but cannot be compared. + kv = "values" + + @property + def view_items_vt(self) -> list[VariableTracker]: + return list(self.view_items) + + def python_type(self) -> type: + return dict_values + + +class DictItemsVariable(DictViewVariable): + kv = "items" + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] + + def python_type(self) -> type: + return dict_items + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # TODO(guilhermeleobas): This should actually check if args[0] + # implements the mapping protocol. + if name == "__eq__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + if isinstance(args[0], DictItemsVariable): + return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) + return ConstantVariable.create(False) + return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Dictionary item views are not hashable in Python. + """ + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf80e45bd0ed597c2d9ae4e3c7e131da52f2d34 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py @@ -0,0 +1,507 @@ +""" +Distributed computing variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for distributed computing components: +- Process Groups (for collective communication) +- Device Meshes (for distributed tensor sharding) +- Placement Types (for specifying distribution strategies) +- Distributed Tensors and their operations +- Backward hooks for distributed module operations + +These classes are responsible for tracking distributed operations during graph +compilation while maintaining proper guards and handling distributed-specific +behaviors. They ensure correct handling of distributed components like process +groups, device meshes, and placement strategies while preserving proper semantics +for distributed tensor operations in the compiled code. + +The implementation provides special handling for distributed package availability +checks and proper tracking of distributed state and operations across processes. +""" + +import functools +import inspect +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING + +import torch +from torch.fx.experimental._backward_state import BackwardState + +from .. import compiled_autograd, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..bytecode_transformation import create_call_function +from ..exc import unimplemented +from ..external_utils import call_module_hooks_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import istype +from .base import VariableTracker +from .constant import ConstantVariable, EnumVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class DistributedVariable(VariableTracker): + """ + The base distributed variable that encapsulates common methods + for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.). + Concrete distributed objects could inherit this class and add object + specific logic. + + i.e. It provides the check on the distributed package existence + and hold the tracking value for the corresponding distributed object. + """ + + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not DistributedVariable.is_available(): + unimplemented( + gb_type="torch.distributed package is not available!", + context="", + explanation="The PyTorch package doesn't include torch.distributed when building from source.", + hints=[ + "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." + ], + ) + self.value = value + + def python_type(self) -> type: + return type(self.value) + + @staticmethod + def is_available() -> bool: + # check if the distributed package is available or not + return torch.distributed.is_available() + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +def is_from_local(value: object) -> bool: + if not DistributedVariable.is_available(): + return False + from torch.distributed.tensor import DTensor + + return inspect.isfunction(value) and value is DTensor.from_local + + +def is_constant_pg_functions(value: object) -> bool: + if not DistributedVariable.is_available(): + return False + + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + + constant_processgroup_functions = [ + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ] + + return inspect.isfunction(value) and value in constant_processgroup_functions + + +class WorldMetaClassVariable(DistributedVariable): + """ + Tracks torch.distributed.GroupMember and torch.distributed.group, which are + instances of the metaclass _WorldMeta. + """ + + @classmethod + def is_group_member_type(cls, value: object) -> bool: + if not cls.is_available(): + return False + + from torch.distributed.distributed_c10d import _WorldMeta + + return type(value) is _WorldMeta + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "WORLD": + assert self.source + source = AttrSource(base=self.source, member="WORLD") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return ProcessGroupVariable(self.value.WORLD) + elif name == "NON_GROUP_MEMBER": + assert self.source + source = AttrSource(base=self.source, member="NON_GROUP_MEMBER") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return EnumVariable(self.value.NON_GROUP_MEMBER) + return super().var_getattr(tx, name) + + +class PlacementClassVariable(DistributedVariable): + @staticmethod + def is_placement_type(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.tensor.placement_types import Placement + + return isinstance(value, type) and issubclass(value, Placement) + + def as_python_constant(self) -> Any: + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.source: + # NOTE: we don't need to track mutations to the placement class as they + # are supposed to be immutable. + new_obj = self.value.__new__(self.value) + var = PlacementVariable(new_obj) + if inspect.getattr_static(self.value, "__init__", None): + var.call_method(tx, "__init__", args, kwargs) + return var + + return super().call_function(tx, args, kwargs) + + +class PlacementVariable(DistributedVariable): + @staticmethod + def is_placement(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch.distributed.tensor.placement_types import Placement + + return isinstance(value, Placement) + + def as_python_constant(self) -> Any: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "dim": + return ConstantVariable.create(self.value.dim) + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from . import ConstantVariable + + # Placement types dynamo tracking only allows following methods + # and __setattr__ is for case like `Shard(dim)` and methods. + # Methods in the list must satisfy: + # 1. Input arguments are constants and do not need to be guarded on; + # 2. Output is constant with respect to their inputs + constant_fold_functions = [ + "__init__", + "__setattr__", + "is_shard", + "is_partial", + "is_replicate", + ] + + if name in constant_fold_functions: + try: + value_type = type(self.value) + if inspect.getattr_static(value_type, "__getattr__", None) is not None: + unimplemented( + gb_type="Placement with custom __getattr__ not supported", + context=f"{value_type.__name__} with custom __getattr__", + explanation="Dynamo does not support Placement types with custom __getattr__ methods", + hints=[ + "Use Placement types without custom __getattr__ methods", + "Move the Placement usage outside the compiled region", + ], + ) + method = inspect.getattr_static(value_type, name) + except AttributeError: + method = None + if method is object.__init__: + return ConstantVariable.create(None) + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + assert method is not None + if name == "__setattr__": + method(self.value, *args, **kwargs) + return self + constant_val = method(self.value, *args, **kwargs) + return ConstantVariable.create(constant_val) + + return super().call_method(tx, name, args, kwargs) # type: ignore[arg-type] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Reconstruct the Placement object by calling its constructor + # e.g., Shard(0), Replicate(), Partial() + from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + + placement_type = type(self.value) + + # Load the placement class + codegen.add_push_null( + lambda: codegen.load_import_from( + "torch.distributed.tensor.placement_types", placement_type.__name__ + ) + ) + + # For Shard, we need to pass the dim argument + if isinstance(self.value, Shard): + codegen(ConstantVariable.create(self.value.dim)) + codegen.extend_output(create_call_function(1, False)) + # Replicate and Partial have no required args + elif istype(self.value, (Replicate, Partial)): + codegen.extend_output(create_call_function(0, False)) + else: + super().reconstruct(codegen) + + +class DeviceMeshVariable(DistributedVariable): + @staticmethod + def is_device_mesh(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.device_mesh import DeviceMesh + + return istype(value, DeviceMesh) + + def as_python_constant(self) -> Any: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "ndim": + return ConstantVariable.create(self.value.ndim) + if name == "device_type": + return ConstantVariable.create(self.value.device_type) + if name == "mesh_dim_names": + source = self.source + if source: + source = AttrSource(base=source, member="mesh_dim_names") + return VariableTracker.build(tx, self.value.mesh_dim_names, source) + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create(self.value.size(*const_args, **const_kwargs)) + if name == "get_coordinate": + return ConstantVariable.create(self.value.get_coordinate()) + if name == "get_rank": + return ConstantVariable.create(self.value.get_rank()) + if name == "get_local_rank": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create( + self.value.get_local_rank(*const_args, **const_kwargs) + ) + if name == "get_group": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ProcessGroupVariable( + self.value.get_group(*const_args, **const_kwargs) + ) + if name == "_get_or_create_default_group": + return ProcessGroupVariable(self.value._get_or_create_default_group()) + if name == "_flatten": + from .builder import SourcelessBuilder + + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return SourcelessBuilder.create( + tx, self.value._flatten(*const_args, **const_kwargs) + ) + return super().call_method(tx, name, args, kwargs) + + +class ProcessGroupVariable(DistributedVariable): + """ + We don't want a ProcessGroup object to end up in our output graph. + + But it's common for dynamo to intercept a PG that is then used to get info like + rank() or world_size(), as well as passed to utility functions in distributed_c10d + which desugar it into plain types like a ranklist and tag. + + For convenience and proper guarding, we construct a variable type. + + TODO: make it possible to use ProcessGroupVariable as input to simple functions + like _expand_group without dynamo complaining about making a proxy for it. + It is not a tensor-like type, and we don't want a proxy- but dynamo assumes + torch library functions are dealing with tensor-like types and would have proxies + for their args. + TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors + or just graph-break whenever one of our special cases is not hit? + """ + + def as_python_constant(self) -> Any: + return self.value + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "rank": + return variables.ConstantVariable.create(self.value.rank()) + if name == "size": + return variables.ConstantVariable.create(self.value.size()) + if name == "_get_backend_name": + return variables.ConstantVariable.create(self.value._get_backend_name()) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "group_name": + return variables.ConstantVariable.create(self.value.group_name) + if name in ["rank", "size"]: + return variables.LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + # TODO should this just raise unimplemented? + return super().var_getattr(tx, name) + + @staticmethod + def is_process_group(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch._C._distributed_c10d import ProcessGroup + from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + + return istype(value, (ProcessGroup, FakeProcessGroup)) + + +class BackwardHookVariable(VariableTracker): + """ + Handles torch.utils.hooks.BackwardHook for module-level backward + hooks. + """ + + @staticmethod + def create( + tx: "InstructionTranslator", + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + ) -> "BackwardHookVariable": + if not compiled_autograd.compiled_autograd_enabled: + unimplemented( + gb_type="Module-level backwards hooks require compiled autograd.", + context="", + explanation="", + hints=[ + "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True." + ], + ) + + def _in_graph_bw_hooks( + bw_state: BackwardState, + ) -> torch.utils.hooks.BackwardHook: + """ + Rather than installing the user hooks in the graph (which + don't survive AotAutograd), we install hooks that will call + trace_wrapped in the backward pass that CompiledAutograd + can turn into actual hook calls. + """ + return torch.utils.hooks.BackwardHook( + None, + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_hooks_name, + module_name=module_name, + ), + ), + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_pre_hooks_name, + module_name=module_name, + ), + ), + ) + + module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod") + user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks) + user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks) + proxy = tx.output.create_proxy( + "call_function", + _in_graph_bw_hooks, + (bw_state_proxy,), + {}, + ) + proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ()) + return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks) + + def __init__( + self, + proxy: torch.fx.Proxy, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + **options: Any, + ) -> None: + super().__init__(**options) + self.proxy = proxy + self.module = module + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + + def as_proxy(self) -> torch.fx.Proxy: + return self.proxy + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ("setup_input_hook", "setup_output_hook"): + return self._setup_hook(tx, name, *args, **kwargs) + return super().call_method(tx, name, args, kwargs) + + def _setup_hook( + self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + hook_method_name, + (self.as_proxy(), args.as_proxy()), + {}, + ), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..9638278300bcf7df327cdd338d927c35f6b6cdad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py @@ -0,0 +1,3059 @@ +""" +Function-related variable tracking classes for Dynamo's symbolic execution. + +This module contains classes that track different types of functions during graph +compilation, including: +- User-defined functions and methods +- Built-in functions and methods +- Wrapped functions (e.g. from decorators) +- Special function types (e.g. functools.partial) +- Triton kernels and related function types + +These classes are responsible for: +- Tracking function calls and their arguments +- Managing function closures and cell variables +- Handling function attributes and special methods +- Maintaining guards for function identity and closure contents +- Supporting function inlining and specialization +- Enabling proper symbolic execution of different function types + +The variable trackers here work together with the rest of Dynamo to enable +accurate graph capture while handling Python's various function-related behaviors. +""" + +import builtins +import functools +import inspect +import itertools +import logging +import sys +import traceback +import types +from collections import namedtuple +from collections.abc import Callable, Sequence +from types import CellType, FunctionType +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import Never +from weakref import WeakKeyDictionary + +import torch +from torch._dynamo.exc import get_stack_above_dynamo +from torch._guards import Source +from torch.utils._pytree import is_namedtuple_class + +from .. import config, graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_rot_n, is_generator +from ..exc import ( + format_skip_frame_message, + get_dynamo_observed_exception, + handle_observed_exception, + InfiniteGeneratorError, + ObservedException, + ObservedGeneratorExit, + ObservedUserStopIteration, + raise_observed_exception, + SkipFrame, + StepUnsupported, + unimplemented, + Unsupported, +) +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + ClosureSource, + CollectionsSource, + ConstantSource, + DefaultsSource, + GetItemSource, + SkipGuardSource, + TorchSource, + TypeSource, +) +from ..utils import ( + check_constant_args, + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + identity, + is_function, + is_wrapper_or_member_descriptor, + istype, + make_cell, +) +from .base import ( + AsPythonConstantNotImplementedError, + AttributeMutationNew, + raise_type_error_exc, + ValueMutationNew, + VariableTracker, +) +from .constant import ConstantVariable + + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import ( + InliningGeneratorInstructionTranslator, + InliningInstructionTranslator, + InstructionTranslator, + InstructionTranslatorBase, + ) + from torch._dynamo.variables.ctx_manager import ContextWrappingVariable + from torch._higher_order_ops.triton_kernel_wrap import ( + TritonGridType, + TritonKernelType, + ) + + from .lists import BaseListVariable, ListVariable + from .tensor import TensorVariable + + +_F = TypeVar("_F", bound=Callable[..., Any]) +CO_VARARGS = 0x04 +CO_VARKEYWORDS = 0x08 +_SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"}) +_TREE_MAP_ONLY_SUPPORTED_KWARGS = frozenset({"is_leaf"}) + + +# Module-level cache keyed by the function object +_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() + + +@functools.lru_cache +def get_pytree_SUPPORTED_NODES_source(): + return AttrSource( + AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" + ) + + +class FunctionSpec: + def __init__(self, func: FunctionType): + code = func.__code__ + vn = code.co_varnames + + self.posonly_count = code.co_posonlyargcount + self.arg_count = code.co_argcount + self.kwonly_count = code.co_kwonlyargcount + + self.posonly_names = vn[: self.posonly_count] + self.pos_or_kw_names = vn[self.posonly_count : self.arg_count] + self.all_pos_names = self.posonly_names + self.pos_or_kw_names + self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count] + + off = self.arg_count + self.kwonly_count + self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None + off += 1 if self.varargs_name else 0 + self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None + + def update_defaults(self, func: FunctionType) -> None: + # Defaults can change from function call to function call. So re-update + # them on every call. + self.defaults = func.__defaults__ or () + self.kwdefaults = func.__kwdefaults__ or {} + + # Map positional-default names → their index in self.defaults + self.pos_default_map = dict( + zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults))) + ) + + +def _get_spec(func: FunctionType) -> FunctionSpec: + spec = _spec_cache.get(func) + if spec is None: + spec = FunctionSpec(func) + _spec_cache[func] = spec + return spec + + +def bind_args_cached( + func: FunctionType, + tx: "InstructionTranslator", + fn_source: Optional[Source], + args: Sequence[Any], + kwargs: dict[str, Any], +) -> dict[str, VariableTracker]: + spec = _get_spec(func) + spec.update_defaults(func) + ba = {} + rem_kw = dict(kwargs) + + # 1) Bind all positional (pos-only + pos-or-kw) + # 1.1) Apply pos-defaults first (maybe overridden later) + for name, idx in spec.pos_default_map.items(): + default_source = None + if fn_source and not ( + ConstantVariable.is_literal(spec.defaults[idx]) + and config.skip_guards_on_constant_func_defaults + ): + default_source = DefaultsSource(fn_source, idx) + ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) + # 1.2) Fill in provided positional args + for i, name in enumerate(spec.all_pos_names): + if i < len(args): + # Maybe override pos-defaults applied above + ba[name] = wrap_bound_arg(tx, args[i]) + elif name in rem_kw and ( + # `kwargs` can have the same key as a pos-only arg `name`. + # If this case happens, we should not consume the `name` here and + # keep it in `kwargs`: + # >>> def fn(a, /, **kwargs): return (a, kwargs) + # >>> fn(1, a=2) + # (1, {'a': 2}) + name not in spec.posonly_names + ): + # Maybe override pos-defaults applied above + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name not in ba: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Missing required positional argument: {name}" + ) + ], + ) + + # 2) *args + extra = args[len(spec.all_pos_names) :] + if spec.varargs_name: + ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra)) + elif extra: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" + ) + ], + ) + + # 3) Keyword-only + for name in spec.kwonly_names: + if name in rem_kw: + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name in spec.kwdefaults: + kwdefault_source = None + if fn_source: + kwdefault_source = DefaultsSource(fn_source, name, is_kw=True) + ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source) + else: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Missing required keyword-only argument: {name}" + ) + ], + ) + + # 4) **kwargs + if spec.varkw_name: + ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw) + elif rem_kw: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}") + ], + ) + + return ba + + +def wrap_bound_arg( + tx: "InstructionTranslator", val: Any, source: Optional[Source] = None +) -> VariableTracker: + # Source propagation is best effort since not every object we encounter has a source to begin with. + if isinstance(val, VariableTracker): + return val + elif not source: + return VariableTracker.build(tx, val) + else: + # Create a lazy variable to avoid guarding on __defaults__ unless really + # needed. + return variables.LazyVariableTracker.create(val, source) + + +def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None: + for k, v in list(result.items()): + if isinstance(v, (tuple, dict)): + # args/kwargs + result[k] = wrap_bound_arg(tx, v) + + +def init_cellvars( + parent: "InstructionTranslator", + result: dict[str, VariableTracker], + code: types.CodeType, +) -> None: + """ + Update `result` to add mapping from local name to new cells created + directly by `code`, or update SideEffects in `parent` if the a local cell is + already in `result` (cell argument). + """ + side_effects = parent.output.side_effects + + for name in code.co_cellvars: + new_cell = side_effects.track_cell_new() + if name in result: + # This handles when a function argument is a cell (e.g., captured by + # a nested func). See `MAKE_CELL` bytecode for more info. + side_effects.store_cell(new_cell, result.pop(name)) + result[name] = new_cell + + +def _create_nested_fn( + code: types.CodeType, + f_globals: dict[str, Any], + name: str, + defaults: Optional[tuple[object, ...]], + closure: Optional[tuple[CellType]], + kwdefaults: Optional[dict[str, Any]], + annotations: Optional[dict[str, Any]], +) -> types.FunctionType: + from types import FunctionType + + func = FunctionType(code, f_globals, name, defaults, closure) + func.__kwdefaults__ = kwdefaults + + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert annotations is None or isinstance(annotations, dict) + func.__annotations__ = annotations # type: ignore[assignment] + + return func + + +fn_known_dunder_attrs = { + "__annotations__", + "__defaults__", + "__kwdefaults__", + "__code__", + "__globals__", + "__closure__", + "__doc__", +} + + +def fn_var_getattr( + tx: "InstructionTranslator", fn: object, source: Optional[Source], name: str +) -> VariableTracker: + source = source and AttrSource(source, name) + + if source and name == "__annotations__": + # We get a large number of silly guards from annotations from inspect + # module. Changing annotations is rare, and it impacting the extracted + # graph is even rarer. So skip guards. + source = SkipGuardSource(source) + + subobj = None + try: + subobj = inspect.getattr_static(fn, name) + except AttributeError: + # function does not have a __getattr__ or __getattribute__ method, + # so we can safely assume that this attribute is absent + raise_observed_exception(AttributeError, tx) + + # Special handling for known dunder attributes + if name in fn_known_dunder_attrs: + subobj = getattr(fn, name) + if source: + return variables.LazyVariableTracker.create(subobj, source) + return VariableTracker.build(tx, subobj) + + +class BaseUserFunctionVariable(VariableTracker): + def get_filename(self) -> str: + return self.get_code().co_filename # type: ignore[attr-defined] + + def get_name(self) -> str: + return self.get_code().co_name # type: ignore[attr-defined] + + def get_globals(self): + raise NotImplementedError + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Ignore patch_track_step_called from torch/optim/lr_scheduler.py - it just patches + # the optimizer.step method and we don't need to trace it + if ( + self.get_name() == "patch_track_step_called" + and self.get_filename().endswith("torch/optim/lr_scheduler.py") + ): + return ConstantVariable.create(None) + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + result = False + + try: + result = hasattr(self.get_function(), name) # type: ignore[attr-defined] + except NotImplementedError: + if name == "__name__" and isinstance(self, NestedUserFunctionVariable): + result = True + return variables.ConstantVariable.create(result) + + def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]: + return {} + + # Override to set whether or not nested graph breaks should be allowed + # if we create an inlining tx for this BaseUserFunctionVariable. + # See symbolic_convert.py for where this function is called. + def should_allow_nested_graph_breaks(self): + return True + + +class UserFunctionVariable(BaseUserFunctionVariable): + """Some unsupported user-defined global function""" + + _nonvar_fields = { + "fn", + "is_constant", + *BaseUserFunctionVariable._nonvar_fields, + } + + _TREE_MAP_MODULES = frozenset( + { + "optree", + "optree.ops", + "torch.utils._pytree", + "torch.utils._cxx_pytree", + } + ) + + @classmethod + def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls(value, source=source) + + def __init__( + self, + fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg] + is_constant: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if getattr(fn, "_dynamo_marked_constant", False): + # This method should be treated as a constant for the purposes of compilation + self.is_constant = True + else: + self.is_constant = False + + # TODO putting this here to avoid duplication, because we could hit this + # from several paths (e.g., SuperVariable or `var_getattr`s). + if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): + unimplemented( + gb_type="can't handle functions not implemented in python ", + context=f"{fn}", + explanation="Dynamo can only handle functions defined in python", + hints=[ + "Move usage of this function out of `torch.compile` region", + *graph_break_hints.INFERENCE_MODE, + ], + ) + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. + # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + self.fn = fn + + def as_python_constant(self) -> Any: + if istype(self, UserFunctionVariable): + return self.fn + # subclasses (such as methods) usually aren't a constant + return super().as_python_constant() + + def self_args(self) -> list[VariableTracker]: + return [] + + def get_function(self) -> types.FunctionType: + return self.fn + + def get_code(self) -> types.CodeType: + return self.fn.__code__ + + def python_type(self) -> type: + return types.FunctionType + + def has_self(self) -> bool: + return getattr(self.fn, "__self__", None) is not None + + def get_globals(self) -> dict[str, Any]: + return self.fn.__globals__ + + def get_source(self) -> Source: + source = self.source + + if source and isinstance(self, variables.UserMethodVariable): + source = self.source_fn # type: ignore[assignment] + return source # type: ignore[return-value] + + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + """ + Assume `args` and `kwargs` are VariableTracker arguments for a call to + this function, create new bindings for initial locals. + """ + assert not self.is_constant + + fn: types.FunctionType = self.fn + + if not isinstance(fn, FunctionType): + raise TypeError("Only supports regular Python functions.") + root_tx = parent.output.root_tx + + source = self.get_source() + result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type] + + init_cellvars(parent, result, fn.__code__) + closure = self.fn.__closure__ or () + assert len(closure) == len(self.fn.__code__.co_freevars) + for idx, name, cell in zip( + itertools.count(), self.fn.__code__.co_freevars, closure + ): + # TODO refactor these 3 branches. + side_effects = parent.output.side_effects + if cell in side_effects: + cell_var = side_effects[cell] + + elif source: + closure_cell = GetItemSource(ClosureSource(source), idx) + closure_cell_contents = AttrSource(closure_cell, "cell_contents") + try: + contents_var = VariableTracker.build( + parent, cell.cell_contents, closure_cell_contents + ) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing( + closure_cell, cell, contents_var + ) + + else: + # TODO figure out why source isn't available here, and whether + # we can fix that and remove this branch. + try: + contents_var = VariableTracker.build(parent, cell.cell_contents) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing(None, cell, contents_var) + + result[name] = cell_var + + return result + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + source = self.get_source() + return fn_var_getattr(tx, self.fn, source, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + result = hasattr(self.fn, name) + return variables.ConstantVariable.create(result) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Handle patch_dynamo_config call + if self.fn is torch._dynamo.patch_dynamo_config: + try: + args_const = [arg.as_python_constant() for arg in args] + kwargs_const = { + key: val.as_python_constant() for key, val in kwargs.items() + } + changes = torch._dynamo.patch_dynamo_config( + *args_const, **kwargs_const + ).changes + return variables.DynamoConfigPatchVariable(changes) + except AsPythonConstantNotImplementedError as e: + raise RuntimeError( + "Cannot convert patch_dynamo_config args/kwargs to constants. " + "Please fix your call to patch_dynamo_config by using simpler inputs. " + f"args: {args}, kwargs: {kwargs}" + ) from e + elif self.fn is torch._dynamo.error_on_graph_break: + try: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + error_on_graph_break = bound.arguments[ + "error_on_graph_break" + ].as_python_constant() + assert isinstance(error_on_graph_break, bool) + return variables.ErrorOnGraphBreakVariable(error_on_graph_break) + except Exception as e: + raise RuntimeError( + "Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). " + f"args: {args}, kwargs: {kwargs}" + ) from e + # Handle a `nonstrict_trace(fn)` call + elif self.fn is torch._dynamo.nonstrict_trace: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + fn_var = bound.args[0] + if not isinstance(fn_var, BaseUserFunctionVariable): + typ = fn_var.python_type() + msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" + unimplemented( + gb_type="TypeError from user code", + context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] + explanation=msg, + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + if not isinstance(fn_var, UserFunctionVariable): + fn_name = fn_var.get_name() + msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950 + unimplemented( + gb_type="Limitation of `nonstrict_trace", + context=f"{self}", + explanation=msg, + hints=[ + f"make sure definition of {fn_name} is outside ", + "`torch.compile` region", + ], + ) + # pyrefly: ignore[missing-attribute] + fn = fn_var.fn + return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) + + if self.is_constant: + return invoke_and_store_as_constant( + tx, self.fn, self.get_name(), args, kwargs + ) + + if ( + not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + and self.fn + is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer + ): + with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer( + tx + ): + return super().call_function(tx, args, kwargs) + + if ( + getattr(tx.output.current_tracer, "description", None) + == "torch.utils.checkpoint.checkpoint" + and not tx.output.current_tracer.allow_side_effects_in_hop + ): + try: + from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState + except Exception: + FSDPState = None # type: ignore[assignment, misc] + if FSDPState is not None and self.fn in [ + FSDPState._pre_forward, + FSDPState._post_forward, + ]: + with torch._dynamo.side_effects.allow_side_effects_in_hop(tx): + return super().call_function(tx, args, kwargs) + + tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs) + if tree_map_result is not None: + return tree_map_result + + return super().call_function(tx, args, kwargs) + + def _maybe_call_tree_map_fastpath( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[VariableTracker]: + rewrite = self._rewrite_tree_map_only_call(tx, args, kwargs) + if rewrite is not None: + tree_map_fn, tree_map_args, tree_map_kwargs = rewrite + else: + tree_map_fn = self + tree_map_args = args + tree_map_kwargs = kwargs + + if not ( + isinstance(tree_map_fn, UserFunctionVariable) + and tree_map_fn._is_tree_map_function() + and not ({*tree_map_kwargs} - _SUPPORTED_TREE_MAP_KWARGS) + and len(tree_map_args) >= 2 + ): + return None + + map_fn = tree_map_args[0] + first_tree = tree_map_args[1] + rest = tree_map_args[2:] + return first_tree.call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _is_tree_map_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _is_tree_map_only_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map_only" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _rewrite_tree_map_only_call( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[ + tuple[ + "UserFunctionVariable", + Sequence[VariableTracker], + dict[str, VariableTracker], + ] + ]: + if not self._is_tree_map_only_function(): + return None + + if len(args) != 3: + return None + if {*kwargs} - _TREE_MAP_ONLY_SUPPORTED_KWARGS: + return None + + type_selector, map_fn, tree_arg = args + allowed_types = self._extract_tree_map_only_types(type_selector) + if allowed_types is None: + return None + + tree_map_callable = self._lookup_tree_map_function() + if tree_map_callable is None: + return None + + wrapped_map_fn = TreeMapOnlyFunctionVariable( + allowed_types, + map_fn, + source=getattr(map_fn, "source", None), + ) + tree_map_variable = variables.UserFunctionVariable(tree_map_callable) + return tree_map_variable, [wrapped_map_fn, tree_arg], dict(kwargs) + + def _lookup_tree_map_function(self) -> Optional[types.FunctionType]: + module_name = getattr(self.fn, "__module__", None) + if not module_name: + return None + module = sys.modules.get(module_name) + if module is None: + return None + tree_map = getattr(module, "tree_map", None) + if isinstance(tree_map, types.FunctionType): + return tree_map + return None + + def _extract_tree_map_only_types( + self, selector: VariableTracker + ) -> Optional[tuple[type, ...]]: + if not selector.is_python_constant(): + return None + try: + raw_value = selector.as_python_constant() + except NotImplementedError: + return None + + flattened = self._flatten_type_spec(raw_value) + if not flattened: + return None + if not all(isinstance(typ, type) for typ in flattened): + return None + return tuple(dict.fromkeys(flattened)) + + def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: + if isinstance(value, type): + return [value] + if isinstance(value, tuple): + collected: list[type] = [] + for entry in value: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + union_type = getattr(types, "UnionType", None) + if union_type is not None and isinstance(value, union_type): + collected = [] + for entry in value.__args__: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + return None + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn + + +class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "allowed_types", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + allowed_types: tuple[type, ...], + map_fn: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.allowed_types = allowed_types + self.map_fn = map_fn + + def python_type(self) -> type: + return FunctionType + + def _matches_allowed_type(self, node: VariableTracker) -> bool: + try: + node_type = node.python_type() + except NotImplementedError: + return False + return any(issubclass(node_type, allowed) for allowed in self.allowed_types) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not args: + return self.map_fn.call_function(tx, args, kwargs) + leaf = args[0] + if self._matches_allowed_type(leaf): + return self.map_fn.call_function(tx, args, kwargs) + if len(args) != 1 or kwargs: + # Defer to the original map function so we fall back to normal + # tracing instead of triggering a graph break. + return self.map_fn.call_function(tx, args, kwargs) + return leaf + + +class BuiltinMethodVariable(BaseUserFunctionVariable): + def __init__( + self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + assert isinstance(fn, types.BuiltinMethodType) + self.fn = fn + + @staticmethod + def is_supported_builtin_method(obj: Any) -> bool: + method_self = obj.__self__ + method_name = obj.__name__ + + # TODO(anijain2305) - Add support for more builtin methods + # Supports tuple.__new__ and frozenset({....}).__contains__ + return (method_self is tuple and method_name == "__new__") or ( + type(method_self) is frozenset and method_name == "__contains__" + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + method_self = self.fn.__self__ + name = self.fn.__name__ + obj_source = self.source and AttrSource(self.source, "__self__") + obj_vt = VariableTracker.build(tx, method_self, obj_source) + return obj_vt.call_method(tx, name, args, kwargs) + + +class LocalGeneratorObjectVariable(VariableTracker): + def __init__( + self, + code: types.CodeType, + f_globals: dict[str, Any], + inline_tracer: "InliningGeneratorInstructionTranslator", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.code = code + self.f_globals = f_globals + self.inline_tracer = inline_tracer + + def get_code(self) -> types.CodeType: + return self.code + + def get_filename(self) -> str: + return self.get_code().co_filename + + def get_name(self) -> str: + return self.get_code().co_name + + def get_function(self) -> Never: + raise NotImplementedError + + def has_self(self) -> bool: + return False + + def __name__(self) -> str: + return self.get_name() + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.get_name()})" + + __repr__ = __str__ + + def reconstruct(self, codegen: "PyCodegen") -> None: + from torch._dynamo.side_effects import disallow_side_effects_in_generator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + save_and_restart_speculation_log, + temporarely_allow_writes_to_output_graph, + ) + + tx = InstructionTranslator.current_tx() + save = save_and_restart_speculation_log(tx) + disallow = disallow_side_effects_in_generator(tx) + temp = temporarely_allow_writes_to_output_graph(tx) + + with save, disallow, temp: + tracer = self.inline_tracer + if not tracer.generator_exhausted: + self.remaining_items = self.force_unpack_var_sequence(tx) + variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) + + def bind_args( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + return self.vt.bind_args(tx, args, kwargs) # type: ignore[attr-defined] + + def get_globals(self) -> dict[str, Any]: + return self.f_globals + + def python_type(self) -> type: + return types.GeneratorType + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + tracer = self.inline_tracer + + if self._is_generator_exhausted(): + raise_observed_exception(StopIteration, tx) + + try: + # Hierarchically, tx can be seen as the parent of the inline tracer + # created on call_function. Any exception needs to be propagated to tx + # for Dynamo to behave correctly + return tracer.inline_call_() + except ObservedException as e: + tracer.generator_exhausted = True + raise e + except InfiniteGeneratorError: + # test/dynamo/test_misc.py::test_iterator_limit + raise + except Unsupported as e: + torch._dynamo.eval_frame.skip_code(self.get_code()) + raise SkipFrame from e + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if name in self.python_type().__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return False + + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[VariableTracker], Any] + ) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + # no nested graph breaks in generators + def should_allow_nested_graph_breaks(self): + return False + + def _setup_exception( + self, tx: "InstructionTranslator", exc: VariableTracker + ) -> None: + tracer = self.inline_tracer + try: + tracer._raise_exception_variable(exc) + except ObservedException as e: + # if no handler is available (i.e. user code doesn't catch it), the + # exception is raised again. + tracer.exception_handler(e) + + def _is_generator_just_started(self) -> bool: + return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + + def _is_generator_exhausted(self) -> bool: + return getattr(self.inline_tracer, "generator_exhausted", False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + # iter(gen) returns itself + return self + elif name == "send": + # Sends a value into the generator function. Returns the next value + # yielded by the generator, or raises StopIteration if the generator + # exits without yielding another value + if self._is_generator_just_started() and len(args): + # can't send non-None value to a just-started generator + # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen + if not all(arg.is_constant_none() for arg in args): + raise_observed_exception(TypeError, tx) + tracer = self.inline_tracer + tracer.push_many(args) + return self.next_variable(tx) + elif name == "close": + # * Raises a GeneratorExit at the point where the generator function was paused. + # * If the generator function catches the exception and returns a + # value, this value is returned from close() - Python 3.13+ + # * If the generator function is already closed, or raises GeneratorExit + # (by not catching the exception), close() returns None. + # * If the generator yields a value, a RuntimeError is raised. + # * If the generator raises any other exception, it is propagated to the caller. + # * If the generator has already exited due to an exception or normal + # exit, close() returns None and has no other effect. + + # Return None if close is called on a just-started generator + # See test GeneratorCloseCpythonTests::test_close_not_started + + tracer = self.inline_tracer + if self._is_generator_just_started() or self._is_generator_exhausted(): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + # Raise GeneratorExit to see if user code catches it. Any other exception + # is propagated to the parent frame. + try: + self._setup_exception( + tx, variables.ExceptionVariable(GeneratorExit, ()) + ) + # There's an extra block on Python 3.12+ to handle StopIteration + # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 + # + # 1 0 RETURN_GENERATOR + # 2 POP_TOP + # 4 RESUME 0 + + # 2 6 LOAD_CONST 1 (1) + # 8 YIELD_VALUE 1 + # 10 RESUME 1 + # 12 POP_TOP + # 14 RETURN_CONST 0 (None) + # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + # 18 RERAISE 1 + # ExceptionTable: + # 4 to 14 -> 16 [0] lasti + if ( + sys.version_info >= (3, 12) + and tracer.next_instruction.opname == "CALL_INTRINSIC_1" + ): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedGeneratorExit: + # If it doesn't catch, we just return None, as per the text above + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + try: + # Raise RuntimeError if the generator yields any other value + if self.next_variable(tx): + raise_observed_exception(RuntimeError, tx) + except ObservedGeneratorExit: + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedUserStopIteration: + # In Python 3.13+, one can capture GeneratorExit and return a value + # See test_generator.py::test_close_capture_GeneratorExit_return + # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26 + # https://github.com/python/cpython/pull/104771 + assert tracer.symbolic_result is not None + return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + # Setup the exception table and jump target in case of try...finally + tracer = self.inline_tracer + try: + # In Python 3.9, the exception is represented as a triple (typ, val, tb) + # In such cases, we re-raise the exception object given to avoid + # creating a new object, so that IS_OP works. + # See: https://github.com/pytorch/pytorch/pull/146496 + self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) + except ObservedException: # noqa: TRY203 + # propagate the exception back to the parent caller + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let's walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + else: + raise_observed_exception(RuntimeError, tracer) + return retval + + return super().call_method(tx, name, args, kwargs) + + +class ContextlibContextManagerLocalGeneratorObjectVariable( + LocalGeneratorObjectVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + + It is a special case of a generator function as we do not allow return a context manager + from a torch.compile function. + """ + + +class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): + """functions that behaves like iterators + + .. note:: + + This is a wrapper around (Nested)UserFunctionVariable + """ + + def __init__( + self, + vt: VariableTracker, + *, + generator_cls: type = LocalGeneratorObjectVariable, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.vt = vt + self.generator_cls = generator_cls + + def __getattr__(self, name): + if name in self.__class__.__dict__: + return getattr(self, name) + return getattr(self.vt, name) + + def get_globals(self) -> dict[str, Any]: + return self.vt.get_globals() # type: ignore[attr-defined] + + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InliningInstructionTranslator": + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + return InliningInstructionTranslator.build_inline_tracer( + tx, + self, + args, + kwargs, + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] + unimplemented( + gb_type="non-generator contextlib.contextmanager", + context=str(self.vt.get_code()), # type: ignore[attr-defined] + explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" + ", i.e. does not use `yield`", + hints=[ + "Use `yield` in the function body instead of `return`.", + "Remove the `@contextlib.contextmanager` decorator.", + ], + ) + + inline_tracer = self._build_inline_tracer(tx, list(args), kwargs) + code = self.vt.get_code() # type: ignore[attr-defined] + f_globals = self.vt.get_globals() # type: ignore[attr-defined] + + # calling a generator returns a generator object + return self.generator_cls( + code, + f_globals, + inline_tracer, # type: ignore[arg-type] + source=self.source, + ) + + +class FunctionDecoratedByContextlibContextManagerVariable( + LocalGeneratorFunctionVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + """ + + def __init__(self, vt: VariableTracker, **kwargs: Any): + super().__init__( + vt, + generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, + **kwargs, + ) + + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InliningGeneratorInstructionTranslator": + # NOTE: This only exists to not break support for context manager when + # config.enable_faithful_generator_behavior = False and + # config.enable_trace_contextlib = True. In case the former is false, + # Dynamo should still be able to trace through @contextmanager functions + tracer = super()._build_inline_tracer(tx, args, kwargs) + assert isinstance( + tracer, + torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator, + ) + tracer.is_generator_from_ctx_manager = True + return tracer + + +class UserMethodVariable(UserFunctionVariable): + """Some unsupported user-defined method""" + + def __init__( + self, + fn: Callable[..., Any], + obj: VariableTracker, + source_fn: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(fn=fn, **kwargs) # type: ignore[arg-type] + self.obj = obj + self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simplly use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that's why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. + if source_fn is None and kwargs.get("source") is not None: + self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.fn}, {self.obj})" + + def self_args(self) -> list[VariableTracker]: + return [self.obj] + + def python_type(self) -> type[types.MethodType]: + return types.MethodType + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NOTE this is to handle methods annotated by `nonstrict_trace`. + # a `nonstrict_trace`-ed function will be wrapped by + # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, + # but in the case of method, we manually wrap it with `UserMethodVariable` + # inside `UserDefinedObjectVariable.var_getattr`. + # + # We might be able to simplify this away by canonicalizing the + # function/method wrapping code paths. + from ..trace_rules import is_nonstrict_trace_callable + + if is_nonstrict_trace_callable(self.fn): + call_args = [*self.self_args(), *args] + var = variables.TorchInGraphFunctionVariable( + self.fn, nonstrict_traceable=True + ) + return var.call_function(tx, call_args, kwargs) + + # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution + # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method + # since we ensure `forward` of allowed modules can be traced by AOT safely. + # Note this is not only for allowed modules, as user customized modules can extend from + # allowed modules but using parent's `forward` method, which is also covered by this branch. + + # If we are tracing the higher order op, we want Dynamo to step inside + # the module call so that Dynamo can see the underlying parameters and + # buffers and raise them as inputs to the graph. The is_root_tracer + # check bypasses the if condition for non-root tracers and directly + # calls the super().call_function at the end, which is basically + # equivalent of inlining the method. + if tx.output.is_root_tracer() and isinstance( + self.obj, variables.NNModuleVariable + ): + module_attr = getattr(self.fn, "__module__", "") + # inline torch.nn.utils.parametrize + if ( + module_attr is not None + and module_attr.startswith("torch.nn.") + and module_attr != "torch.nn.utils.parametrize" + or self.is_constant + ): + return self.obj.call_method( + tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant + ) + elif ( + _fsdp_param_group is not None + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined] + ): + return variables.TorchCtxManagerClassVariable(self.fn).call_function( + tx, (self.obj, *args), kwargs + ) + if self.is_constant: + fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined] + return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) + return super().call_function(tx, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__self__": + return self.obj + if name == "__func__": + # We might have a better way to access the function object, this + # information is stored in self.source_fn, use that to construct the + # variable tracker. + return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] + return super().var_getattr(tx, name) + + +class WrappedUserMethodVariable(UserMethodVariable): + def __init__( + self, + wrapped: UserMethodVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn", None) + kwargs.pop("obj", None) + super().__init__(wrapped.fn, wrapped.obj, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrappedUserFunctionVariable(UserFunctionVariable): + def __init__( + self, + wrapped: UserFunctionVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn", None) + super().__init__(wrapped.fn, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +def invoke_and_store_as_constant( + tx: "InstructionTranslator", + fn: Callable[..., Any], + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> VariableTracker: + def convert(x: VariableTracker) -> Any: + if x.is_tensor(): + return cast("TensorVariable", x).get_real_value() + return x.as_python_constant() + + args = [convert(x) for x in args] + kwargs = {k: convert(v) for k, v in kwargs.items()} + res = fn(*args, **kwargs) + return tx.output.register_attr_or_module( + res, + name, + source=ConstantSource(name), + ) + + +class NestedUserFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "f_globals", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + fn_name: VariableTracker, + code: VariableTracker, + f_globals: dict[str, Any], + defaults: Optional[VariableTracker], + kwdefaults: Optional[VariableTracker], + annotations: Optional[VariableTracker], + closure: Optional[VariableTracker], + # This is present when this function is created by + # `functools.wrap(wrapped_fn)(this_fn)`. + wrapped_fn: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + if kwargs.get("mutation_type") is None: + kwargs.update(mutation_type=AttributeMutationNew()) + super().__init__(**kwargs) + assert isinstance(fn_name.as_python_constant(), str) + assert isinstance(code.as_python_constant(), types.CodeType) + assert isinstance(f_globals, dict) + self.fn_name = fn_name + self.code = code + self.f_globals = f_globals + self.defaults = defaults + self.kwdefaults = kwdefaults + self.annotations = annotations + self.closure = closure + self.wrapped_fn: Optional[VariableTracker] = wrapped_fn + + def self_args(self) -> list[VariableTracker]: + return [] + + def get_code(self) -> types.CodeType: + return self.code.as_python_constant() + + def python_type(self) -> type: + return types.FunctionType + + def get_function(self) -> types.FunctionType: + if self.closure: + raise NotImplementedError + func = types.FunctionType( + self.code.as_python_constant(), + self.f_globals, + self.fn_name.as_python_constant(), + ) + if self.defaults: + func.__defaults__ = self.defaults.as_python_constant() + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.as_python_constant() + if self.annotations: + annotations = self.annotations.as_python_constant() + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert isinstance(annotations, dict) + func.__annotations__ = annotations + return func + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ) -> VariableTracker: + tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined] + return ConstantVariable(None) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__setattr__": + return self.call_setattr(tx, *args) + return super().call_method(tx, name, list(args), kwargs) + + def has_closure(self) -> bool: + return self.closure is not None + + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + if name == "__name__": + return self.get_name() + if name == "__code__": + return self.get_code() + if name == "__defaults__": + d = getattr(self, "defaults", None) + return d.as_python_constant() if d else None + return super().const_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if name == "__code__": + return variables.ConstantVariable.create(hasattr(self, "code")) + if name == "__defaults__": + return variables.ConstantVariable.create(hasattr(self, "defaults")) + return super().call_obj_hasattr(tx, name) + + def has_self(self) -> bool: + return False + + def get_globals(self) -> dict[str, Any]: + return self.f_globals + + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + code = self.get_code() + func = types.FunctionType( + code, + self.f_globals, + self.fn_name.as_python_constant(), + tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined] + tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), + ) + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined] + bound = inspect.signature(func).bind(*args, **kwargs) + bound.apply_defaults() + result = dict(bound.arguments.items()) + wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type] + init_cellvars(parent, result, code) + + for idx, name in enumerate(code.co_freevars): + assert name not in result + cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr] + result[name] = cell + + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from(__name__, "_create_nested_fn") + ) + codegen(self.code) + codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) + codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined] + + if self.defaults: + codegen(self.defaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.closure: + codegen(self.closure) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.kwdefaults: + codegen(self.kwdefaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.annotations: + try: + annotations = self.annotations.as_python_constant() + codegen.extend_output( + [codegen.create_load_const_unchecked(annotations)] + ) + except NotImplementedError: + codegen(self.annotations) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + codegen.extend_output(create_call_function(7, False)) + + if self.wrapped_fn: + codegen.add_push_null( + lambda: codegen.load_import_from("functools", "wraps") + ) + codegen(self.wrapped_fn) + codegen.extend_output(create_call_function(1, False)) + codegen.extend_output(create_rot_n(2)) + codegen.extend_output(create_call_function(1, True)) + + # codegen attributes + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if tx.output.side_effects.has_pending_mutation(self): + for name, value in tx.output.side_effects.store_attr_mutations[ + self + ].items(): + codegen.dup_top() + codegen(value) + codegen.extend_output(create_rot_n(2)) + codegen.store_attr(name) + + +class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): + def __init__( + self, + wrapped: Any, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn_name", None) + kwargs.pop("code", None) + kwargs.pop("f_globals", None) + kwargs.pop("defaults", None) + kwargs.pop("kwdefaults", None) + kwargs.pop("annotations", None) + kwargs.pop("closure", None) + kwargs.pop("wrapped_fn", None) + super().__init__( + wrapped.fn_name, + wrapped.code, + wrapped.f_globals, + wrapped.defaults, + wrapped.kwdefaults, + wrapped.annotations, + wrapped.closure, + wrapped.wrapped_fn, + ) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class SkipFunctionVariable(VariableTracker): + _nonvar_fields = { + "value", + "reason", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value: Any, reason: Optional[str] = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + self.reason = reason + + def as_python_constant(self) -> Any: + return self.value + + @classmethod + def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": + # Use closure match guard (i.e. guard on __code__ object instead of + # function id) to avoid guarding on nested functions. + if inspect.getattr_static(value, "_torchdynamo_disable", False): + # For torch._dynamo.disable function, ensure that the original + # function is guarded. Otherwise, the else branch will guard on the + # _dynamo.disable.__code__ + guard_on_source = source + guard_on_value = value + + while getattr(guard_on_value, "_torchdynamo_orig_callable", False): + guard_on_value = guard_on_value._torchdynamo_orig_callable + guard_on_source = AttrSource( + guard_on_source, "_torchdynamo_orig_callable" + ) + + guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH) + elif inspect.isbuiltin(value): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + elif not is_wrapper_or_member_descriptor(value): + # These descriptors are not guaranteed to return the same object on + # attribute lookup. They are unlikely to be changed, so we can skip + # guarding them. + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls(value, source=source) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) + unimplemented( + gb_type="Skip calling `torch.compiler.disable()`d function", + context=str(self.value), + explanation=f"Skip calling function `{self.value}` since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", + hints=[ + "Remove the `torch.compiler.disable` call", + ], + ) + elif self.value is torch._dynamo.graph_break: + graph_break_msg = kwargs.get("msg") + if graph_break_msg: + graph_break_msg = graph_break_msg.as_python_constant() + unimplemented( + gb_type="Call to `torch._dynamo.graph_break()`", + context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + explanation=f"User-inserted graph break. Message: {graph_break_msg}", + hints=[ + "Remove the `torch._dynamo.graph_break()` call.", + ], + ) + elif self.value is torch._dynamo.skip_frame: + skip_frame_msg = kwargs.get("msg") + if skip_frame_msg: + skip_frame_msg = skip_frame_msg.as_python_constant() + else: + skip_frame_msg = "" + raise SkipFrame( + format_skip_frame_message( + tx.f_code, + f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}", + ) + ) + elif self.value is torch._dynamo.step_unsupported: + raise StepUnsupported + else: + if config.dont_skip_tracing: + from .builder import SourcelessBuilder + + # re-build the function, attempting to not skip + rebuilt_fn = SourcelessBuilder.create(tx, self.value) + # if we still get SkipFunctionVariable, then we *really* should skip this function + if not isinstance(rebuilt_fn, SkipFunctionVariable): + return rebuilt_fn.call_function(tx, args, kwargs) + qualname = getattr(self.value, "__qualname__", "") + module_or = getattr(self.value, "__module__", None) + module_name = "" if module_or is None else str(module_or) + try: + path = inspect.getfile(self.value) + explanation = ( + f"Dynamo developers have intentionally marked that the function `{qualname}` " + f"in file `{path}` should not be traced." + ) + hints = [ + f"Avoid calling the function `{qualname}`.", + ] + # TODO improve trace_rules reasoning to provide better hints. + # How do we tell that a function/file should NOT be removed from skip files? + # Do a very basic check for now. + if "_dynamo" not in path: + hints += [ + f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` " + "to force tracing into the function. " + "More graph breaks may occur as a result of attempting to trace into the function.", + "Please file an issue to PyTorch.", + ] + except TypeError: + known_python_builtin_modules = {"_abc", "_warnings"} + if module_or in known_python_builtin_modules: + explanation = ( + f"Dynamo does not know how to trace the Python builtin " + f"`{module_name}.{qualname}`." + ) + hints = [ + "If you are attempting to call a logging function (e.g. `_warnings.warn`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please file an issue on GitHub " + "so the PyTorch team can add support for it. ", + ] + elif module_or is not None and module_or.startswith("optree"): + explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}." + hints = [ + " Consider using torch.utils._pytree - " + "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py" + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + else: + explanation = ( + f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` " + f"This function is either a Python builtin (e.g. _warnings.warn) " + f"or a third-party C/C++ Python extension (perhaps created with pybind)." + ) + hints = [ + "If it is a Python builtin, please file an issue on GitHub " + "so the PyTorch team can add support for it and see the next case for a workaround.", + "If it is a third-party C/C++ Python extension, please " + "either wrap it into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html " + "for more details) or, if it is traceable, use " + "`torch.compiler.allow_in_graph`.", + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + if qualname == "allow_in_graph": + explanation = ( + "Found an allow_in_graph decorator to a function which " + "is created inside the parent function that is getting " + "compiled. This is not supported for now." + ) + hints = [] + reason = self.reason if self.reason else "" + unimplemented( + gb_type="Attempted to call function marked as skipped", + context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}", + explanation=explanation, + hints=hints, + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(self.value, name)) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + + return fn_var_getattr(tx, self.value, self.source, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class WrappedSkipFunctionVariable(SkipFunctionVariable): + def __init__( + self, + wrapped: VariableTracker, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("value", None) + kwargs.pop("reason", None) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) # type: ignore[attr-defined] + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrapperUserFunctionVariable(VariableTracker): + """ + Used to represent a wrapper object that contains the actual callable as an + attribute. For example, torch.jit.script/trace have the original function at + their _torchdynamo_inline attribute. Similarly, functions with + __script_if_tracing_wrapper have the original attr at "__original_fn". + """ + + def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.wrapper_obj = wrapper_obj + self.attr_to_trace = attr_to_trace + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == self.attr_to_trace: + val = getattr(self.wrapper_obj, self.attr_to_trace) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, val, source) + + return super().var_getattr(tx, name) + + def self_args(self) -> list[VariableTracker]: + return [] + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if hasattr(self.wrapper_obj, "cache_info"): + target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) + module_name = getattr(target_fn, "__module__", "") or "" + + if module_name.split(".", maxsplit=1)[0] != "torch": + msg = ( + "Dynamo detected a call to a `functools.lru_cache`-wrapped " + "function. Dynamo ignores the cache wrapper and directly " + "traces the wrapped function. Silent incorrectness is only " + "a *potential* risk, not something we have observed. " + 'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.' + ) + + torch._dynamo.utils.warn_once(msg) + + dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo") + if dynamo_logger.isEnabledFor(logging.DEBUG): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack = get_stack_above_dynamo() + user_stack + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n" + user_stack_trace += str(user_stack_formatted) + dynamo_logger.debug(user_stack_trace) + + all_args = self.self_args() + list(args) + return variables.UserFunctionVariable( + polyfills.getattr_and_trace # type: ignore[arg-type] + ).call_function( + tx, + [self, variables.ConstantVariable(self.attr_to_trace), *all_args], + kwargs, + ) + + +class WrapperUserMethodVariable(WrapperUserFunctionVariable): + """ + Similar to WrapperUserFunctionVariable, but for methods. The only delta is + saving the vt for `self` object of the method which is then used by + WrapperUserFunctionVariable in `call_function` method. + """ + + def __init__( + self, + wrapper_obj: Any, + attr_to_trace: str, + self_obj: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(wrapper_obj, attr_to_trace, **kwargs) + self.obj = self_obj + + def self_args(self) -> list[VariableTracker]: + return [self.obj] + + +def _traceable_collective_remaps() -> dict[Any, Any]: + # We can't rely on importing from distributed, since it's not always built + if torch.distributed.is_available(): + from torch.distributed._functional_collectives import ( + traceable_collective_remaps, + ) + + return traceable_collective_remaps + return {} + + +def _traceable_collectives_source( + tx: "InstructionTranslator", fn: Callable[..., Any] +) -> AttrSource: + assert torch.distributed.is_available(), "Illegal invocation." + assert fn in _traceable_collective_remaps().values() + + inner_name = fn.__name__ + path_source = tx.import_source("torch.distributed._functional_collectives") + return AttrSource(path_source, inner_name) + + +class CollectiveFunctionRewriteVariable(UserFunctionVariable): + """ + Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives. + + This class provides both a way to check if a function is remappable, and perform the remapping. + + In the case that a function is 'remappable' but only for some combinations of call-time arguments, + we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse + than status-quo as we currently graph-break on all distributed.* collectives. + """ + + def __init__( + self, + fn: Callable[..., Any], + *, + replacement_var: UserFunctionVariable, + **kwargs: Any, + ) -> None: + super().__init__(fn, **kwargs) # type: ignore[arg-type] + assert isinstance(replacement_var, UserFunctionVariable) + self.replacement_var = replacement_var + + @staticmethod + def create( + tx: "InstructionTranslator", + old_fn: Callable[..., Any], + source: Source, + **options: Any, + ) -> "CollectiveFunctionRewriteVariable": + new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) + return CollectiveFunctionRewriteVariable( + old_fn, + replacement_var=UserFunctionVariable(new_fn, source=new_source, **options), + source=source, + **options, + ) + + @staticmethod + def can_rewrite(variable: Any) -> bool: + return ( + inspect.isfunction(variable) and variable in _traceable_collective_remaps() + ) + + @staticmethod + def rewrite( + tx: "InstructionTranslator", fn: Callable[..., Any] + ) -> tuple[Any, AttrSource]: + new_fn = _traceable_collective_remaps()[fn] + return new_fn, _traceable_collectives_source(tx, new_fn) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # call_function must check any unsupported arguments and graph-break. + # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, + # since that's the contract for putting a mapping in `traceable_collective_remaps` + import torch.distributed as dist + from torch.distributed._functional_collectives import REDUCE_OP_TO_STR + + # Merge args into kwargs so positional and keyword args + # can be processed the same way. + signature = inspect.signature(self.fn) + kwargs = dict(signature.bind(*args, **kwargs).arguments) + args = () + + if "async_op" in kwargs and kwargs["async_op"].as_python_constant(): + unimplemented( + gb_type="async_op=True for distributed collectives", + context=f"{self.fn}, {args=}, {kwargs=}", + explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if self.fn in ( + dist.all_reduce, + dist.reduce_scatter_tensor, + dist._reduce_scatter_base, + ): + reduce_op_var = kwargs.get("op") + reduce_op = ( + reduce_op_var.value # type: ignore[attr-defined] + if reduce_op_var is not None + else signature.parameters["op"].default + ) + if reduce_op not in REDUCE_OP_TO_STR: + raise ValueError(f"Unsupported all_reduce op: {reduce_op}") + kwargs["op"] = variables.ConstantVariable.create( + REDUCE_OP_TO_STR[reduce_op] + ) + return self.replacement_var.call_function(tx, args, kwargs) + + +class FunctoolsWrapsVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not kwargs and len(args) == 1: + + def wraps(fn: Any) -> VariableTracker: + if isinstance(fn, variables.NestedUserFunctionVariable): + return fn.clone(wrapped_fn=args[0]) + unimplemented( + gb_type="functools.wraps", + context=f"{fn}", + explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.LambdaVariable(wraps) + + return super().call_function(tx, args, kwargs) + + +class CollectionsNamedTupleFunction(UserFunctionVariable): + def as_python_constant(self) -> Any: + return self.fn + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + constant_args = check_constant_args(args, kwargs) + if constant_args: + try: + value = self.fn( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + except TypeError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + return variables.UserDefinedClassVariable( + # pyrefly: ignore[unbound-name] + value, + mutation_type=ValueMutationNew(), + ) + unimplemented( + gb_type="namedtuple construction", + context=f"{args=}, {kwargs=}", + explanation="`torch.compile` only support certain input types for namedtuple", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FunctoolsPartialVariable(VariableTracker): + def __init__( + self, + func: VariableTracker, + args: Sequence[VariableTracker], + keywords: dict[str, VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.func = func + assert isinstance(args, list) + self.args = args + assert isinstance(keywords, dict) + self.keywords = keywords + # fake_value is used for id calculation. Creating this value and id'ng + # on it is sufficient for the tracing purposes. + self.fake_value = functools.partial(identity) + + def python_type(self) -> type: + return functools.partial + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) + codegen(self.func) + if self.args: + codegen.foreach(self.args) + if not self.keywords: + codegen.extend_output(create_call_function(len(self.args) + 1, False)) + return + + codegen.foreach(self.keywords.values()) + keys = tuple(self.keywords.keys()) + codegen.extend_output( + codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) + ) + + def get_function(self) -> Any: + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + merged_args = self.args + list(args) + merged_kwargs = {**self.keywords, **kwargs} + return self.func.call_function(tx, merged_args, merged_kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + # functools.partial uses slots, so attributes are constant + return variables.ConstantVariable.create( + hasattr(functools.partial(identity), name) + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + # Handle __slots__ + if name == "func": + return self.func + if name == "args": + return variables.ListVariable(self.args, source=source) + if name == "keywords": + items = {ConstantVariable.create(k): v for k, v in self.keywords.items()} + return variables.ConstDictVariable(items, source=source) + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + raise_observed_exception(AttributeError, tx) + + def as_python_constant(self) -> Any: + return functools.partial( + self.func.as_python_constant(), + *[arg.as_python_constant() for arg in self.args], + **{k: v.as_python_constant() for k, v in self.keywords.items()}, + ) + + def guard_as_python_constant(self) -> Any: + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + return functools.partial( + self.func.guard_as_python_constant(), + *[v.guard_as_python_constant() for v in self.args], + **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, + ) + + def is_python_hashable(self) -> bool: + return ( + self.func.is_python_hashable() + and all(arg.is_python_hashable() for arg in self.args) + and all(value.is_python_hashable() for value in self.keywords.values()) + ) + + def get_python_hash(self): + func_hash = self.func.get_python_hash() + args_hash = (arg.get_python_hash() for arg in self.args) + values_hash = (value.get_python_hash() for value in self.keywords.values()) + return hash((func_hash, *args_hash, *values_hash)) + + def is_python_equal(self, other): + return ( + self.func.is_python_equal(other.func) + and all( + arg_a.is_python_equal(arg_b) + for (arg_a, arg_b) in zip(self.args, other.args) + ) + and all( + value_a.is_python_equal(value_b) + for (value_a, value_b) in zip( + self.keywords.values(), other.keywords.values() + ) + ) + ) + + +class PolyfilledFunctionVariable(VariableTracker): + _nonvar_fields = { + "fn", + "wrapped_fn", + "traceable_fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + @functools.cache + def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: + return {} + + @classmethod + def create_with_source( + cls, value: Any, source: Source + ) -> "PolyfilledFunctionVariable": + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + + return cls(value, source=source) + + def __init__(self, fn: _F, **kwargs: Any) -> None: + super().__init__(**kwargs) + # pyrefly: ignore[invalid-type-var] + self.fn: _F = fn + + handler = self._get_polyfill_handlers().get(fn, fn) + traceable_fn = None + assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" + for candidate_attr in ( + "__torch_dynamo_polyfill__", # registered polyfill + "__python_implementation__", # self handler from third-party libraries + ): + candidate = getattr(handler, candidate_attr, None) + if candidate: + assert callable(candidate) + traceable_fn = candidate + break + else: + raise RuntimeError( + f"Polyfill handler {handler} does not have a traceable function" + ) + # pyrefly: ignore[invalid-type-var] + self.wrapped_fn = handler + # pyrefly: ignore[invalid-type-var] + self.traceable_fn: _F = traceable_fn + + @property + def polyfill_fn(self) -> Callable[..., Any]: + return self.traceable_fn + + def can_constant_fold_through(self) -> bool: + return getattr( + self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False + ) + + def get_function(self) -> Any: + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + result = ( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + return VariableTracker.build(tx, result) + + # Special case for sum on tuple/list of ints + if ( + self.fn is builtins.sum + and len(args) == 1 + and not kwargs + and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) + and all( + (x.is_python_constant() and isinstance(x.as_python_constant(), int)) + or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) + for x in args[0].items + ) + ): + return variables.SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", + torch.sym_sum, + (tuple(a.as_proxy() for a in args[0].items),), + {}, + ), + sym_num=torch.sym_sum( + [ + ( + x.as_python_constant() + if x.is_python_constant() + else x.sym_num # type: ignore[attr-defined] + ) + for x in args[0].items + ] + ), + ) + + traceable_function_variable = VariableTracker.build(tx, self.traceable_fn) + return traceable_function_variable.call_function(tx, args, kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + if not (method or is_function(method)): + raise_type_error_exc(tx, f"Cannot find callable {name} in {self.fn}") + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + # pyrefly: ignore[bad-specialization] + polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) + return polyfilled_method_variable.call_function(tx, args, kwargs) + + def as_python_constant(self) -> Any: + return self.fn + + +class TracebackVariable(VariableTracker): + # We don't track traceback. A call to any function in this module is a no-op + def call_function( # type: ignore[empty-body] + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ... + + +class SysFunctionVariable(VariableTracker): + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": + if len(tx.exn_vt_stack): + exn = tx.exn_vt_stack[-1] + typ = exn.exc_type # type: ignore[union-attr] + tb = None + items = [ + VariableTracker.build(tx, typ), + exn, + VariableTracker.build(tx, tb), + ] + else: + items = [ + variables.ConstantVariable(None), + variables.ConstantVariable(None), + variables.ConstantVariable(None), + ] + return variables.TupleVariable(items) # type: ignore[arg-type] + + def exception(self, tx: "InstructionTranslator") -> VariableTracker: + return self.exc_info(tx).items[1] + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is sys.exc_info: + return self.exc_info(tx) + assert self.value is sys.exception + return self.exception(tx) + + +from torch._higher_order_ops.triton_kernel_wrap import ( + create_tma_experimental_metadata, + create_tma_stable_metadata, + TMADescriptorMetadata, + TritonHOPifier, +) + + +class DynamoTritonHOPifier(TritonHOPifier): + def raise_unsupported(self, msg: str) -> Never: + unimplemented( + gb_type="triton kernel unsupported feature", + context="", + explanation=f"Encountered triton kernel unsupported feature: {msg}", + hints=[], + ) + + def is_callable(self, maybe_callable: VariableTracker) -> bool: + return isinstance( + maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) + ) + + def get_value(self, val: VariableTracker) -> Any: + return val.value # type: ignore[attr-defined] + + def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]: + from .lists import BaseListVariable + + if isinstance(grid, BaseListVariable): + return grid.as_proxy() + else: + unimplemented( + gb_type="unsupported grid type for triton hop check_grid", + context=f"grid type = {type(grid)}", + explanation="`torch.compile` only supports list-like grid for check_grid", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_grid( + self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator" + ) -> Any: + meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta_var], {}) + return grid + + # We use this function to wrap call_prune_configs + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + tx: Optional["InstructionTranslator"], + variable: Any, + ) -> VariableTracker: + from .builder import SourcelessBuilder + + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type] + result = wrapped_user_function.call_function(tx, args, kwargs) + return result + + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Any, + name: str, + ) -> VariableTracker: + from .builder import VariableBuilder + + wrapped_user_obj = VariableBuilder( + tx, AttrSource(variable.kernel_source, f"{name}") + )._wrap(user_obj) + return wrapped_user_obj + + def maybe_unpack_configs( + self, configs: Any, tx: Optional["InstructionTranslator"] + ) -> list[Any]: + # unpack the list of configs + configs = configs.unpack_var_sequence(tx) + + # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed. + configs = [config.guard_as_python_constant() for config in configs] + + return configs + + def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: + if not result.is_python_constant(): + self.raise_unsupported( + "@triton.heuristics must return constant values because configs can only contain constant values." + ) + + return result.guard_as_python_constant() + + # We need to override call_getitem here so that we can add the source in the case + # where we call the triton kernel with a grid + def call_getitem( # type: ignore[override] + self, + variable: "TritonKernelVariable", + args: Sequence[Any], + ) -> "TritonKernelVariable": + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if variable.grid is not None or len(args) != 1: + self.raise_unsupported( + "Triton kernels should be called with only a single grid" + ) + return type(variable)( + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=args[0], + kernel_source=variable.source, + ) + + def call_HOP( + self, + variable: "TritonKernelVariable", + grids: Any, + combined_args_raw: dict[str, Any], + tx: "InstructionTranslator", + ) -> "variables.ConstantVariable": + from .dicts import ConstDictVariable + + # as we can only pass tensors as non-const args in fx graph, + # here we replace TMA descriptors + # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable + # instances) with the underlying tensors, while moving the + # TMA descriptor-related metadata to a separate argument, + # so that we can reconstruct the TMA descriptors downstream + tma_descriptor_metadata: TMADescriptorMetadata = {} + for k in list(combined_args_raw.keys()): + v = combined_args_raw[k] + if isinstance( + v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable) + ): + tma_descriptor_metadata[k] = v.to_metadata() + combined_args_raw[k] = v.get_tensor() + + combined_args = { + variables.ConstantVariable.create(k): v + for k, v in combined_args_raw.items() + } + + from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_mutation, + ) + + # Combine args and kwargs and pass as a dict so that if user defined triton + # kernel uses variables as 'grid' or 'kernel', it does not conflict with + # parameters of the wrapper function + constant_args = { + k: v.as_python_constant() + for k, v in combined_args_raw.items() + if isinstance(v, VariableTracker) and v.is_python_constant() + } + non_constant_args = { + k: v + for k, v in combined_args.items() + if not (isinstance(v, VariableTracker) and v.is_python_constant()) + } + + for v in non_constant_args.values(): + v = v.realize() + if not (v.is_tensor() or v.is_symnode_like()): + self.raise_unsupported( + f"Unexpected argument type for a Triton kernel: {repr(v)}." + ) + + constant_args_idx = kernel_side_table.add_constant_args(constant_args) + meta = ConstDictVariable(non_constant_args, dict) + tx.output.create_proxy( + "call_function", + triton_kernel_wrapper_mutation, + (), + { + "kernel_idx": variable.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grids, + "tma_descriptor_metadata": tma_descriptor_metadata, + "kwargs": meta.as_proxy(), + }, + ) + + return variables.ConstantVariable( + None, + ) + + +dynamo_triton_hopifier_singleton = DynamoTritonHOPifier() + + +class TritonKernelVariable(VariableTracker): + grid: "TritonGridType" + kernel: "TritonKernelType" + kernel_idx: Optional[int] + kernel_source: "AttrSource" + + def __init__( + self, kernel: Any, kernel_idx: Optional[int], grid: Any, **kwargs: Any + ) -> None: + self.kernel_source = kwargs.pop("kernel_source", None) + super().__init__(**kwargs) + dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value] + self, args, kwargs, tx + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + return dynamo_triton_hopifier_singleton.call_getitem(self, args) + elif name == "run": + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] + + # Bail out to parent's implementation + return super().call_method(tx, name, args, kwargs) + + def specialize_symbolic(self, arg: Any) -> Any: + from .constant import ConstantVariable + from .tensor import SymNodeVariable + + # See [Note: Specialize tl.constexpr args in user-defined triton kernels] + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + return arg + + +class TMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + data_ptr: "variables.DataPtrVariable", + dims: list[VariableTracker], + block_dims: list[VariableTracker], + element_size: VariableTracker, + **kwargs: Any, + ) -> None: + assert isinstance(data_ptr, variables.DataPtrVariable) + super().__init__(**kwargs) + self.data_ptr = data_ptr + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + + def to_metadata(self) -> Any: + return create_tma_experimental_metadata( + [dim.as_proxy() for dim in self.dims], + [dim.as_proxy() for dim in self.block_dims], + self.element_size.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.experimental_descriptor", + f"create_{len(self.dims)}d_tma_descriptor", + ) + ) + self.data_ptr.reconstruct(codegen) + args = [*self.dims, *self.block_dims, self.element_size] + codegen.foreach(args) + codegen.call_function(len(args) + 1, False) + + def get_tensor(self) -> VariableTracker: + return self.data_ptr.from_tensor + + +class TMADescriptorStableVariable(VariableTracker): + def __init__( + self, + tensor: "TensorVariable", + block_shape: "ListVariable", + **kwargs: Any, + ) -> None: + assert tensor.is_tensor() + super().__init__(**kwargs) + self.tensor = tensor + self.block_shape = block_shape + + def to_metadata(self) -> Any: + return create_tma_stable_metadata( + self.block_shape.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.tensor_descriptor", + "TensorDescriptor", + ) + ) + codegen.load_method("from_tensor") + self.tensor.reconstruct(codegen) + codegen(self.block_shape) + codegen.call_method(2) + + def get_tensor(self) -> Any: + return self.tensor + + +class CreateTMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + rank: int, + **kwargs: Any, + ) -> None: + assert rank in (1, 2) + super().__init__(**kwargs) + self.rank = rank + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] + + if not isinstance(ptr, variables.DataPtrVariable): + raise Unsupported( + "Please ensure there were no graph breaks between " + f"create_{self.rank}d_tma_descriptor and the upstream " + ".data_ptr() call." + ) + + if self.rank == 1: + if len(args) + len(kwargs) != 4: + raise_type_error_exc( + tx, + f"TMA metadata rank=1 requires exactly 4 arguments, got {len(args) + len(kwargs)}", + ) + dims = [ + kwargs["dim"] if "dim" in kwargs else args[1], + ] + block_dims = [ + kwargs["block_dim"] if "block_dim" in kwargs else args[2], + ] + else: + if len(args) + len(kwargs) != 6: + raise_type_error_exc( + tx, + f"TMA metadata rank=2 requires exactly 6 arguments, got {len(args) + len(kwargs)}", + ) + dims = [ + kwargs["dim1"] if "dim1" in kwargs else args[1], + kwargs["dim0"] if "dim0" in kwargs else args[2], + ] + block_dims = [ + kwargs["block_dim1"] if "block_dim1" in kwargs else args[3], + kwargs["block_dim0"] if "block_dim0" in kwargs else args[4], + ] + element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] + + return TMADescriptorExperimentalVariable( + data_ptr=ptr, + dims=dims, + block_dims=block_dims, + element_size=element_size, + ) + + +class CreateTMADescriptorStableVariable(VariableTracker): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] + block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] + + return TMADescriptorStableVariable( + tensor=tensor, # type: ignore[arg-type] + block_shape=block_shape, # type: ignore[arg-type] + ) + + +class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. + + def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if len(args) != 1: + raise_type_error_exc( + tx, + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + ) + type_source = None + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) + type_source = TypeSource(args[0].source) + python_type = args[0].python_type() + if is_namedtuple_class(python_type): + type_source = AttrSource(CollectionsSource(), "namedtuple") + return VariableTracker.build(tx, namedtuple, type_source) + return VariableTracker.build(tx, python_type, source=type_source) + + +class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. + + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + When is_leaf is None (the common case), we can optimize by not tracing into the function. + When is_leaf is not None, we fall back to regular tracing since it requires executing user code. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # tree_is_leaf(tree, is_leaf=None) + if len(args) < 1 or len(args) > 2: + raise_type_error_exc( + tx, + f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + ) + + # Check if is_leaf parameter is provided + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) + if len(args) == 2: + is_leaf = args[1] + + if not is_leaf.is_constant_none(): + return super().call_function(tx, args, kwargs) + + # Optimize the case where is_leaf is None + # return _get_node_type(tree) not in SUPPORTED_NODES + tree = args[0] + node_type_var = PyTreeGetNodeTypeFunctionVariable( + torch.utils._pytree._get_node_type + ).call_function(tx, [tree], {}) + + # If the SUPPORTED_NODES was seen earlier and mutated, there would be a + # source and that will give us the mutated SUPPORTED_NODES. + supported_nodes_var = VariableTracker.build( + tx, + torch.utils._pytree.SUPPORTED_NODES, + source=get_pytree_SUPPORTED_NODES_source(), + ) + out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) + return ConstantVariable.create(not out.value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..253386a94eeee02876fc0ce2fc7ca7036f170aaf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py @@ -0,0 +1,4827 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for handling higher-order operators in Dynamo. +It provides functionality for tracing and transforming control flow constructs like +conditions (torch.cond), loops (torch.while_loop), maps (torch.ops.higher_order.map), +and other higher-order operations. + +The module includes specialized VariableTracker classes for different types of +higher-order operations, along with utilities for: +- Speculating and capturing subgraphs +- Managing control flow +- Handling autograd function applications +- Supporting function transformations +- Processing activation checkpoints + +These classes work together to enable Dynamo to correctly trace and compile code +containing complex control flow patterns and higher-order functions while preserving +their semantic behavior. +""" + +import contextlib +import functools +import inspect +import itertools +import logging +import types +import warnings +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Literal, Optional, TYPE_CHECKING + +import torch._C +import torch.fx +import torch.nn +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import get_fake_value +from torch._dynamo.variables.builtin import BuiltinVariable +from torch._dynamo.variables.constant import ConstantVariable +from torch._dynamo.variables.ctx_manager import RepararametrizeModuleContextVariable +from torch._dynamo.variables.functions import UserFunctionVariable +from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable +from torch._dynamo.variables.tensor import SymNodeVariable, TensorVariable +from torch._guards import Source +from torch._ops import HigherOrderOperator +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree + +from .. import graph_break_hints, variables +from ..exc import ( + ObservedException, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) +from ..source import AttrSource, DictGetItemSource +from ..utils import proxy_args_kwargs, set_example_value +from .base import VariableTracker +from .dicts import ConstDictVariable +from .lazy import LazyVariableTracker +from .lists import ListVariable, TupleVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + + +@dataclass +class OutputSpec: + """ + Contains the treespec of the output of the speculated subgraph, and the + information to mask out the constant values from the output during + flattening and inserting them back during unflattening. Cleaning up + constants from the graph makes the graph simpler for AOTDispatcher and + Inductor. + """ + + treespec: pytree.TreeSpec + # list of True/False to identify the locations of const values in the + # subgraph output. True means that value at that index is a constant. + masks_to_filter_const_values: Optional[list[bool]] = None + # The actual constant values that were present in the subgraph output. Note + # that this is the same length as the mask, we just look at the indices + # where mask is True. + const_values: Optional[list[Any]] = None + # Number of intermediate nodes that are also made subgraph outputs. + num_intermediate_nodes_as_outputs: int = 0 + + def __post_init__(self): + if ( + self.masks_to_filter_const_values is not None + or self.const_values is not None + ): + assert len(self.masks_to_filter_const_values) == len(self.const_values) + + +def raise_hard_error_if_graph_break(reason): + def deco(fn): + @functools.wraps(fn) + def graph_break_as_hard_error(*args, **kwargs): + try: + return fn(*args, **kwargs) + except (Unsupported, ObservedException) as e: + import sys + + if isinstance(e, Unsupported): + exc = UncapturedHigherOrderOpError( + f"{reason} Got {e.msg}", e.real_stack + ) + else: + msg = e.msg if hasattr(e, "msg") else type(e) + real_stack = e.real_stack if hasattr(e, "real_stack") else None + exc = UncapturedHigherOrderOpError( + f"{reason} Got {msg}", real_stack + ) + raise exc.with_traceback(sys.exc_info()[2]) from None + + return graph_break_as_hard_error + + return deco + + +# This function is a syntax sugar for creating a dummy new subtracer so that +# newly added nodes are added to a separate subgraph in this subtracer instead of affecting +# the main graph. This is useful for creating sample inputs for tracing the subgraph. +# For example, in FlexAttentionHigherOrderVariable, we want to create several scalars +# to trace the score_mod function but we don't want the operators that creates the scalar to +# show up in the graph, we could this function to discard the graph changes. +# Example usage: +# with discard_graph_changes(): +# sample_input= create_sample_inputs() +# speculate_subgraph(tx, f, sample_inputs, {}) +@contextlib.contextmanager +def discard_graph_changes(tx): + ctx = tx.output.subtracer("subgraph_wrapper", None) + try: + ctx.__enter__() + yield + finally: + ctx.__exit__(None, None, None) + + +def check_meta_consistency_vt( + vars1: list[VariableTracker], + vars2: list[VariableTracker], + lhs_name: str, + rhs_name: str, + include_contiguity: bool = True, +) -> None: + from torch._higher_order_ops.utils import check_meta_consistency + + def _unwrap_var(var): + if var.is_tensor(): + return var.proxy.node.meta["example_value"] + elif isinstance(var, SymNodeVariable): + return var.sym_num + elif var.is_python_constant(): + return var.as_python_constant() + else: + unimplemented( + gb_type="cannot unwrap variable for check_meta_consistency", + context=str(var), + explanation=f"Expected {var} to be TensorVariable, SymNodeVariable, or ConstantVariable", + hints=[], + ) + + unwrapped1 = [_unwrap_var(var) for var in vars1] + unwrapped2 = [_unwrap_var(var) for var in vars2] + + return check_meta_consistency( + unwrapped1, + unwrapped2, + lhs_name, + rhs_name, + include_contiguity=include_contiguity, + ) + + +@contextlib.contextmanager +def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): + from . import GradModeVariable + + org_value = torch.is_grad_enabled() + try: + GradModeVariable.create(tx, enable, initialized=True) + yield + finally: + GradModeVariable.create(tx, org_value, initialized=True) + + +@contextlib.contextmanager +def dynamo_allow_side_effects_in_hop(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.allow_side_effects_in_hop + try: + tx.output.current_tracer.allow_side_effects_in_hop = True + yield + finally: + tx.output.current_tracer.allow_side_effects_in_hop = orig_val + + +def find_mismatched_vars(var, types, allow_none=False): + """ + Recursively finds variables whose type is not an instance of the specified types. + Args: + var: The variable to check. + types: A tuple of allowed types. + allow_none (bool): Whether to allow None values. Defaults to False. + Returns: + A set of variables whose type is not an instance of the specified types. + """ + mismatched_vars = set() + if isinstance(var, (list, tuple)): + for item in var: + mismatched_vars.update(find_mismatched_vars(item, types, allow_none)) + elif isinstance(var, (TupleVariable, ListVariable)): + for item in var.items: + mismatched_vars.update(find_mismatched_vars(item, types, allow_none)) + elif isinstance(var, ConstDictVariable): + for value in var.items.values(): + mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) + else: + if not isinstance(var, types) and not (allow_none and var.is_constant_none()): + mismatched_vars.add(var) + return mismatched_vars + + +def only_consist_of(var, types, allow_none=False): + mismatch_vars = find_mismatched_vars(var, types, allow_none=allow_none) + return len(mismatch_vars) == 0 + + +# A more read-able syntax sugar for creating a UserFunctionVariable for f +# and run call_function on it. Make it return a function to preserve the calling +# convention of the original f. +def _make_inlined(tx: "InstructionTranslator", f): + assert callable(f), "Expect f to be a python callable." + + def inline_call(*args, **kwargs): + return UserFunctionVariable(f).call_function(tx, args, kwargs) + + return inline_call + + +def _call_function_with_auto_output_flattening( + tx: "InstructionTranslator", + fn: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + flat_example_value: Any, + body_r: Optional[VariableTracker], + graph_output_vts: VariableTracker | tuple[VariableTracker, ...], +) -> Optional[VariableTracker]: + """ + Create HOP call node and reproxify output VTs for HOPs with auto output semantics. + + This function is used by HOPs with auto output semantics (see speculate_subgraph_with_auto_output_flattening) + to create the actual HOP call in the FX graph and properly handle the output variable trackers. + + The key operation is "reproxifying" - updating the proxies of the original tensor VTs + (from body_r) to point to the HOP call outputs, ensuring the outer graph correctly + references the HOP outputs while allowing body_r to contain arbitrary Python objects. + + Args: + tx: The instruction translator + fn: The HOP function to call + args: Arguments for the HOP call (typically includes the subgraph node) + kwargs: Keyword arguments for the HOP call + flat_example_value: Example value for the HOP output + body_r: The output VT structure that Dynamo continues tracing with (may be None) + graph_output_vts: Tensor/symint VTs that were actual graph outputs + + Returns: + The body_r VT (unchanged), which Dynamo will continue tracing with + """ + from .builder import wrap_fx_proxy + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # wrap_fx_proxy creates fresh variable trackers. However, the main program + # after the speculate subgraph can still use the original tensor vts that + # are still pointing to the nodes present in the subgraph. So, we reproxify + # the original tensor vts with the subgraph outputs. This way, whenever the + # outer graph uses an original vt, it uses the subgraph output. + # + # This is critical for maintaining the separation between: + # - `body_r`: The output VT structure that Dynamo continues tracing (may + # contain non-proxyable objects, nested structures, etc.) + # - `graph_output_vts`: Only the tensor/symint VTs that were actual graph + # outputs from speculate_subgraph + # + # By overwriting the proxies of VTs in `body_r` with the proxies from the + # HOP call, we ensure the outer graph correctly references the HOP outputs + # while still allowing `body_r` to contain arbitrary Python objects. + if body_r is not None: + for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items): + if orig_vt.is_tensor() or isinstance(orig_vt, SymNodeVariable): + assert subgraph_vt.is_tensor() or isinstance( + subgraph_vt, SymNodeVariable + ) + orig_vt.proxy = subgraph_vt.proxy + return body_r + + +def _call_function_and_unflatten_output( + tx, fn, args, kwargs, flat_example_value, ret_spec, body_r +): + from .builder import wrap_fx_proxy + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # wrap_fx_proxy creates fresh variable trackers. However, the main program + # after the speculate subgraph can still use the original tensor vts that + # are still pointing to the nodes present in the subgraph. So, we reproxify + # the original tensor vts with the subgraph outputs. This way, whenever the + # outer graph uses an original vt, it uses the subgraph output. + if body_r is not None: + for orig_vt, subgraph_vt in zip(body_r.items, flat_variable.items): + if orig_vt.is_tensor() or isinstance(orig_vt, SymNodeVariable): + assert subgraph_vt.is_tensor() or isinstance( + subgraph_vt, SymNodeVariable + ) + orig_vt.proxy = subgraph_vt.proxy + + if ret_spec.num_intermediate_nodes_as_outputs: + # The treespec was computed w/o any extra intermediate outputs. At this + # point, it is safe to just get rid of the extra outputs + flat_variable = TupleVariable( + flat_variable.items[: -ret_spec.num_intermediate_nodes_as_outputs] + ) + + if ret_spec.masks_to_filter_const_values: + from torch._dynamo.external_utils import insert_const_values_with_mask + + # During flattening, we removed the constant values. To ensure Dynamo + # can trace correctly, insert back the constant values in the output. + flat_variable = _make_inlined(tx, insert_const_values_with_mask)( + flat_variable, ret_spec.masks_to_filter_const_values, ret_spec.const_values + ) + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {}) + return ( + _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_spec.treespec) + if ret_spec.treespec + else flat_variable + ) + + +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = { + id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor) + } + output_tensor_ids = { + id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor) + } + assert input_tensor_ids.isdisjoint(output_tensor_ids), ( + "inputs to function body cannot alias outputs" + ) + + +def get_tensor_storages(tensor: torch.Tensor) -> set[StorageWeakRef]: + """ + Get storage references from a tensor. + + Handles regular tensors. Raises NotImplementedError for sparse tensors + and traceable wrapper subclasses. + + Args: + tensor: The tensor to extract storages from + + Returns: + Set of StorageWeakRef objects for the tensor's storage(s) + """ + from torch.multiprocessing.reductions import StorageWeakRef + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + storages: set[StorageWeakRef] = set() + + if not isinstance(tensor, torch.Tensor): + return storages + + if tensor.is_sparse or tensor.is_sparse_csr: + raise NotImplementedError("get_tensor_storages does not support sparse tensors") + + if is_traceable_wrapper_subclass(tensor): + raise NotImplementedError( + "get_tensor_storages does not support traceable wrapper subclasses" + ) + else: + storages.add(StorageWeakRef(tensor._typed_storage())) + + return storages + + +class StorageAliasingTracker: + """ + Tracks storage references to detect aliasing between tensors. + + This class encapsulates the logic for collecting storages from tensors + and checking for aliasing conflicts. Used to filter intermediate outputs + that would create input-output or output-output aliasing. + """ + + def __init__(self): + self.excluded_storages: set = set() + + def _collect_storages_from_tensor(self, example_value): + self.excluded_storages.update(get_tensor_storages(example_value)) + + def collect_from_inputs(self, tx): + """Collect storages from graph input placeholders.""" + from torch._higher_order_ops.utils import _collect_fake_inputs + + for node in tx.output.graph.nodes: + if node.op == "placeholder": + example_value = _collect_fake_inputs([node])[0] + if isinstance(example_value, torch.Tensor): + self._collect_storages_from_tensor(example_value) + else: + break + + def collect_from_outputs(self, graph_output_vts): + """Collect storages from existing graph outputs.""" + from torch._higher_order_ops.utils import _collect_fake_inputs + + for vt in graph_output_vts: + proxy = vt.as_proxy() + example_value = _collect_fake_inputs([proxy.node])[0] + if isinstance(example_value, torch.Tensor): + self._collect_storages_from_tensor(example_value) + + def check_and_track(self, proxy_node) -> bool: + """ + Check if a tensor can be added as a subgraph output without causing aliasing issues. + + Given a proxy node, extracts its example tensor value and checks if its storage + aliases with any previously tracked storages (from inputs or other outputs). + If there's no aliasing conflict, the tensor's storage is added to the tracked set. + + Args: + proxy_node: An FX proxy node whose example_value is the tensor to check. + + Returns: + True if the tensor doesn't alias with tracked storages (safe to add as output), + False if it aliases (should be filtered out). + """ + from torch._higher_order_ops.utils import _collect_fake_inputs + from torch.multiprocessing.reductions import StorageWeakRef + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + example_value = _collect_fake_inputs([proxy_node])[0] + + # Non-tensor outputs (e.g., symints) don't have aliasing concerns + if not isinstance(example_value, torch.Tensor): + return True + + # Check if any storage aliases with existing inputs/outputs + tensor_storages = get_tensor_storages(example_value) + if tensor_storages & self.excluded_storages: + return False + + # Track this tensor's storage (for wrapper subclasses, inner storages were already checked) + if not is_traceable_wrapper_subclass(example_value): + if not (example_value.is_sparse or example_value.is_sparse_csr): + self.excluded_storages.add( + StorageWeakRef(example_value._typed_storage()) + ) + + return True + + +def collect_intermediate_outputs( + tx, subtracer, graph_output_vts, filter_aliased_intermediates=False +): + extra_outputs = [] + existing_out_proxies = {vt.as_proxy() for vt in graph_output_vts} + + # Build the aliasing tracker if we're filtering + tracker = None + if filter_aliased_intermediates: + tracker = StorageAliasingTracker() + tracker.collect_from_inputs(tx) + tracker.collect_from_outputs(graph_output_vts) + + for out in subtracer.tracked_tensor_or_symint_vt: + proxy = out.as_proxy() + + # Skip if already in output + if proxy in existing_out_proxies: + continue + + # TODO floats are not supported in HOP input/output + if isinstance(out, SymNodeVariable) and out.python_type() is float: + continue + + if not filter_aliased_intermediates: + extra_outputs.append(out) + else: + # Filter out intermediates that alias with inputs or outputs. + # This is needed for HOPs like invoke_subgraph that don't support aliasing. + # TODO: If a filtered intermediate is captured by side effects (e.g., appended + # to a list), it will fail later with "does not belong to this Graph" error + # when the outer graph tries to use it. See test_side_effect_with_aliased_intermediate. + if tracker.check_and_track(proxy.node): + extra_outputs.append(out) + + return extra_outputs + + +def _check_all_tensorvariable(args): + if not all(type(a.realize()) is TensorVariable for a in args): + unimplemented( + gb_type="HOP: non torch.Tensor leaf", + context=f"args types: {[type(a.realize()) for a in args]}", + explanation="Expected all leaves to be of torch.Tensor type.", + hints=[], + ) + + +def _check_supported_callable_arg( + tx: "InstructionTranslator", func_var: VariableTracker, arg_name +): + is_callable = ( + BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant() + ) + if not is_callable: + unimplemented( + gb_type="HOP: non-callable variable", + context=f"arg name: {arg_name}, func_var type: {str(func_var)}", + explanation=f"{arg_name} should be a callable but is of type {str(func_var)}.", + hints=[], + ) + + +def _call_while_loop( + self: VariableTracker, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + stack_output: bool, +) -> VariableTracker: + from torch._higher_order_ops.while_loop import _create_unbacked_symint + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + cond_fn, body_fn, operands, additional_inputs = args + + # Input checks + for i, k in enumerate(["cond_fn", "body_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + if kwargs or len(args) != 4: + unimplemented( + gb_type="torch.while_loop: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.while_loop expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: while_loop(cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # cond_fn and body_fn input check + _check_supported_callable_arg(tx, cond_fn, "cond_fn") + _check_supported_callable_arg(tx, body_fn, "body_fn") + + # operands input check + operands_seq = operands.unpack_var_sequence(tx) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.while_loop: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) + + with discard_graph_changes(tx): + # Note: this must be run under discard graph changes. + def unspecialize_carried_inputs(tx, carry) -> VariableTracker: + # See NOTE [unspecialize int carry with unbacked symints] + if ( + carry.is_python_constant() + and isinstance(carry.as_python_constant(), int) + ) or isinstance(carry, SymNodeVariable): + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value + ) + return SymNodeVariable.create(tx, proxy, example_value) + else: + # See NOTE [unspecialize constant tensor carry] + assert carry.is_tensor() + cloned_carry = carry.clone() + cloned_carry.proxy.node.meta["example_value"].constant = None + return cloned_carry + + # clone inputs across subgraphs, to avoid unbacked memoization in fake prop + cond_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if carry.is_tensor() + else carry + ), + ) + for carry in operands_seq + ] + body_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if carry.is_tensor() + else carry + ), + ) + for carry in operands_seq + ] + + # create cond subgrpahs + ( + (cond_r, _cond_treespec), + cond_graph, + cond_lifted_freevars, + ) = speculate_subgraph( + tx, + cond_fn, + cond_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + # NOTE [why we cannot use "automatic" for while_loop]: + # The reason is that we want to enforce + # the ordering of inputs and outputs to be consistent and the ordering + # of cond_fn and body_fn to the consistent. + # e.g. suppose we use "automatic" and we have: + # + # def body_fn(ph1, ph2): + # new_a, new_b = ph2.cos(), ph1.sin() + # return new_a, new_b + # + # a, b = torch.randn(3), torch.randn(3) + # new_a, new_b = body_fn(a, b) + # + # Using automatic, the ordering of arguments will be the order that they're + # used. In this example, the capture graph looks like: + # + # def captured_body(ph1, ph2): + # new_a, new_b = ph1.cos(), ph2.add_(1) + # return new_a, new_b + # + # This is fine when we change the calling convention of captured_body to be + # new_a, new_b = captured_body(b, a). + # But for while_loop, the next iteration's input is previous iteration output + # we'll end up feeding captured_body(new_a, new_b) instead. + # So it's best we always enforce the ordering of carried_inputs the same as outputs + # with "flatten_manual". + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + remove_consts_from_outputs=False, + ) + cond_nn_modules = dict(tx.output.nn_modules) + validate_subgraph_output_types(cond_r) + if cond_r.is_tensor(): + cond_r_meta = _extract_tensor_metadata( + cond_r.proxy.node.meta["example_value"], include_contiguity=False + ) + if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]): + unimplemented( + gb_type="torch.while_loop: unsupported cond_fn return type", + context=str(cond_r), + explanation=f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + elif cond_r.is_python_constant(): + # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False + pred = cond_r.as_python_constant() + if pred: + unimplemented( + gb_type="torch.while_loop: infinite loop detected", + context=str(cond_r), + explanation=f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + else: + return operands + + # create body subgraph + ( + (body_r, body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + body_fn, + body_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + supports_input_mutation=False, + supports_aliasing=False, + remove_consts_from_outputs=False, + ) + validate_subgraph_output_types(body_r) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + body_r.unpack_var_sequence(tx), + operands_seq, + "body_fn_output", + "carried_inputs", + include_contiguity=False, + ) + + ( + cond_graph, + body_graph, + cond_shared, + _body_shared, + cond_unique, + body_unique, + ) = _merge_graph_inputs( + cond_graph, + cond_lifted_freevars, + "cond_fn", + body_graph, + body_lifted_freevars, + "body_fn", + ) + + # Note: cond_shared and body_shared refer to the same proxy in parent graph + # so using either of them is OK. Use cond_shared as it doesn't matter. + additional_lifted_inputs = cond_shared + cond_unique + body_unique + + body_nn_modules = dict(tx.output.nn_modules) + + cond_gm = torch.fx.GraphModule(cond_nn_modules, cond_graph) + body_gm = torch.fx.GraphModule(body_nn_modules, body_graph) + cond_name = tx.output.install_subgraph("cond_fn", cond_gm) + body_name = tx.output.install_subgraph("body_fn", body_gm) + + cond_node = make_attr(tx, cond_name) + body_node = make_attr(tx, body_name) + + operands_proxy = tuple(operand.as_proxy() for operand in operands_seq) + additional_inputs_proxy = tuple( + [inp.as_proxy() for inp in additional_inputs_seq] + additional_lifted_inputs + ) + p_args = ( + cond_node, + body_node, + operands_proxy, + additional_inputs_proxy, + ) + return _call_function_and_unflatten_output( + tx, + self.value, + p_args, + {}, + None, + body_treespec, + body_r, + ) + + +def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode): + from torch._subclasses._fake_tensor_utils import _CacheKeyState + from torch._subclasses.fake_tensor import extract_tensor_metadata + + # Maps the equivalent nodes from a to b + node_map = {} + + def check_all_args(a_nodes, b_nodes): + for arg_a, arg_b in zip(a_nodes, b_nodes): + if isinstance(arg_a, torch.fx.Node): + if node_map[arg_a] != arg_b: + return False + elif isinstance(arg_a, slice): + if not isinstance(arg_b, slice): + return False + if not check_all_args( + (arg_a.start, arg_a.stop, arg_a.step), + (arg_b.start, arg_b.stop, arg_b.step), + ): + return False + elif arg_a != arg_b: + # This is a catch-all for everything else. `slice` was a + # surprise but can there be other data structures that can + # contain fx.Nodes in them? + return False + return True + + for a_node, b_node in zip(a_mod.graph.nodes, b_mod.graph.nodes): + if a_node.op != b_node.op: + return False + + if a_node.op == "placeholder": + a_value = a_node.meta["example_value"] + b_value = b_node.meta["example_value"] + + if isinstance(a_value, torch.Tensor): + if not isinstance(b_value, torch.Tensor): + return False + # Extract fake tensor metadata for a and b and then compare + a_result = [] + state = _CacheKeyState(fake_mode.shape_env) + a_metadata = extract_tensor_metadata(a_value) + a_metadata._flatten_into(a_result, fake_mode, state) + + b_result = [] + state = _CacheKeyState(fake_mode.shape_env) + b_metadata = extract_tensor_metadata(b_value) + b_metadata._flatten_into(b_result, fake_mode, state) + if a_result != b_result: + return False + elif isinstance(a_value, torch.SymInt): + if not isinstance(b_value, torch.SymInt): + return False + if a_value is not b_value: + return False + elif a_node.op == "call_function": + if a_node.target is not b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_function): %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "call_method": + if a_node.target != b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_method) : %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "output": + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug("%s: Graph comparison failed at the output node", fn_name) + return False + elif a_node.op == "get_attr": + a_attr = getattr(a_mod, a_node.target) + b_attr = getattr(b_mod, b_node.target) + if isinstance(a_attr, torch.fx.GraphModule): + if not isinstance(b_attr, torch.fx.GraphModule): + return False + # This is an example of a HOP inside a HOP + if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode): + return False + else: + # TODO - write an example with tensor as a graph attribute in + # the Fx graph + raise NotImplementedError(f"get_attr with {type(a_attr)}") + else: + # TODO - call_module is not supported because Dynamo Fx graph does + # not install a call_module + raise NotImplementedError(f"Graph equivalence check saw a {a_node.op}") + + # Two nodes are equal - add them to them map + node_map[a_node] = b_node + + return True + + +def validate_args_and_maybe_create_graph_inputs( + sub_args, + tracer, + tx, + set_subgraph_inputs, + description, + sub_args_names=None, +): + from . import AutogradFunctionContextVariable + from .builder import wrap_fx_proxy_cls + + assert tracer.parent is not None + + if set_subgraph_inputs == "flatten_manual": + flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)( + ListVariable(sub_args) + ).unpack_var_sequence(tx) + + flat_inputs = validate_args_and_maybe_create_graph_inputs( + flat_args.unpack_var_sequence(tx), + tracer, + tx, + set_subgraph_inputs="manual", + description=description, + ) + + return _make_inlined(tx, pytree.tree_unflatten)( + ListVariable(flat_inputs), tree_spec + ).unpack_var_sequence(tx) + else: + if sub_args_names is not None: + # Can be greater if user passes some args as kwargs + assert len(sub_args_names) >= len(sub_args) + args = [] + for idx, a in enumerate(sub_args): + assert isinstance(a, VariableTracker) + if set_subgraph_inputs == "automatic": + args.append(a) + continue + elif set_subgraph_inputs == "automatic_with_forced_inputs": + if isinstance(a, variables.TensorVariable): + node = a.maybe_fx_node() + example_value = node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + example_value = node.meta.get("example_value", None) + a = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + elif set_subgraph_inputs == "semi_automatic": + if isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + example_value = node.meta.get("example_value", None) + a = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + args.append(a) + continue + + if a.is_python_constant(): + # This arg is not used in the body of the higher order op. + # Currently, this new input is added to make the calls + # happy, which expect a fixed number of arguments. In + # future, we can clean this up. + arg_name = ( + "const_unused" + if sub_args_names is None + else f"const_unused_{sub_args_names[idx]}" + ) + tracer.create_graph_input( + arg_name, a.python_type(), a.as_python_constant() + ) + new_arg = a + # Weird special case, we probably want to delete it or fold it + # into the next case (of `a` being placeable into a graph) + elif isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + new_arg = a + # If `a` can be put into a graph + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = node.meta.get("example_value", None) + arg_name = node.name if sub_args_names is None else sub_args_names[idx] + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + new_arg = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + # If `a` cannot be put into a graph + else: + # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). + unimplemented( + gb_type="HOP body taking non-Tensor as input", + context=str(sub_args), + explanation=f"{description} with body that accepts non-Tensors as input. " + f"Got type {a.python_type()} at index {idx}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + args.append(new_arg) + return args + + +# This helper function is used to make sure two graphs share the same input signature. For example, +# in torch.cond, two branches might lift different set of tensors as inputs. This function helps to +# dedup the inputs and modify the graphs to take the same set of inputs. +def _merge_graph_inputs( + l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name +): + def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars): + # The nn module attributes are guaranteed to be registered into the top-level graph module during + # higher order op speculation. Therefore, get_attr nodes in two branches with the same + # target refer to the same attribute and we can safely deduplicate them with their target. + # + # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But + # true_branch and false_branch belong to two separate tracing contexts, they may register the same + # attribute to top level separately. This creates two get_attr proxies for the same attribute + # that have different meta data such as stack_trace (one stack trace for the true_branch, + # and the other for false_branch). It seems better to discard the proxy explicitly in cond + # than make dynamo create a single proxy for the same get_attr target. + def shared_getattrs(l_lifted_proxies, r_lifted_proxies): + true_targets = { + proxy.node.target: proxy + for proxy in l_lifted_proxies + if proxy.node.op == "get_attr" + } + l_shared_getattrs = {} + r_shared_getattrs = {} + + for false_proxy in r_lifted_proxies: + if ( + false_proxy.node.op == "get_attr" + and false_proxy.node.target in true_targets + ): + true_proxy = true_targets[false_proxy.node.target] + l_shared_getattrs[true_proxy] = true_proxy + r_shared_getattrs[false_proxy] = true_proxy + return l_shared_getattrs, r_shared_getattrs + + l_shared_getattrs, r_shared_getattrs = shared_getattrs( + l_lifted_freevars.keys(), r_lifted_freevars.keys() + ) + + l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + l_shared_getattrs.keys() + ) + r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + r_shared_getattrs.keys() + ) + unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars + unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars + + def _sort_by_name(vars): + return sorted(vars, key=lambda var: var.node.name) + + return ( + list(_sort_by_name(list(l_shared_freevars))), + list(_sort_by_name(list(r_shared_freevars))), + list(_sort_by_name(list(unique_l_freevars))), + list(_sort_by_name(list(unique_r_freevars))), + ) + + (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars( + l_lifted_freevars, r_lifted_freevars + ) + + # Let's say we capture cond(pred, true_fn, false_fn, (x,)) + # With set_graph_input set to automatic, + # true_fn has lifted variables x, a, b, c + # false_fn has lifted variables x, a, b, d + # Then fixup_branch_inps make sure both branches have the same signature, i.e.: + # - true_fn(x, a, b, c_true_branch, d_false_branch) + # - false_fn(x, a, b, c_true_branch, d_false_branch) + # + # More formally, the signature has three parts in the following order: + # 1. used in both branches: x, a, b + # 2. only used in true branches: c, suffixed with _true_branch + # 3. only used in false branches: d, suffixed with _false_branch + # Within each part, we re-order the nodes by name to have a derterministic ordering for testing. + def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): + def _insert_or_replace_phs(new_args, name_suffix): + for arg in new_args: + new_ph = graph.placeholder(arg.node.name + name_suffix) + new_ph.meta = arg.node.meta + # Override with new_ph if there exists a old placeholder. + if arg in lifted_freevars: + old_ph = lifted_freevars[arg].node + old_ph.replace_all_uses_with(new_ph) + # replace_all_uses_with doesn't clean users. Clean it manually so that we could erase it. + old_ph.users = {} + graph.erase_node(old_ph) + + first_not_ph_node = next( + node for node in graph.nodes if node.op != "placeholder" + ) + with graph.inserting_before(first_not_ph_node): + _insert_or_replace_phs(shared, "") + _insert_or_replace_phs(unique_l, "_" + l_name) + _insert_or_replace_phs(unique_r, "_" + r_name) + + fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r) + fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r) + return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r + + +# NOTE: [HigherOrderOperator subgraph input ordering] +# The input ordering of the higher order ops is determined by the order of +# the creation of the placeholder. +# Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before +# speculating subgraph. +# During subgraph speculation, we may lift closured tensors and free symbols as inputs, +# their ordering is determined by the time they are lifted: earlier lifted ones precede later +# lifted ones. +# +# Suppose the placeholders are +# O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs +# The following code re-order the placeholders to +# O1, O2, O3, O4, O5, X1, X2, X3 +def move_lifted_freevars_phs_to_end( + graph: torch.fx.Graph, lifted_freevars: dict[Any, torch.fx.Node] +): + lifted_ph_set = {child_p.node for child_p in lifted_freevars.values()} + + prev_phs = [n for n in graph.nodes if n.op == "placeholder"] + + # No need to reorder when graph doesn't have args or doesn't + # have lifted freevars or all inputs are lifted freevars. + if ( + len(prev_phs) == 0 + or len(lifted_ph_set) == 0 + or len(prev_phs) == len(lifted_ph_set) + ): + return + + # Step 1: find first X1 + for x1 in prev_phs: + if x1 in lifted_ph_set: + break + + assert x1 is not None and x1.op == "placeholder" + # Step 2: starting from the X1, skip Xs and prepend Os before X1. + cand_x = x1.next + while cand_x is not None and cand_x.op == "placeholder": + if cand_x in lifted_ph_set: + cand_x = cand_x.next + else: + nxt = cand_x.next + cand_x._remove_from_list() + x1.prepend(cand_x) + cand_x = nxt + + # Step 3: assert that all placeholders are in the correct order as . + # in lifted_freevars + after_phs = [node for node in graph.nodes if node.op == "placeholder"][ + -len(lifted_freevars) : + ] + assert len(after_phs) == len(lifted_freevars) + for child_proxy, ph in zip(lifted_freevars.values(), after_phs): + assert child_proxy.node is ph, ( + "The order of placeholders is different from the order of lifted_freevars" + ) + + graph.lint() + + +def check_aliasing_and_input_mutation( + subtracer, graph, supports_input_mutation, supports_aliasing, source_target +): + if not supports_input_mutation: + mutation_info = subtracer.has_input_mutation() + if mutation_info.has_mutation: + context = f"{mutation_info.msg} in\n {graph}" + unimplemented( + gb_type="Encountered input mutation during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support input mutation. Found in {source_target.name}", + hints=[ + "Consider using the debug context to change user code to avoid mutation.", + "Please open an issue.", + ], + ) + + if not supports_aliasing: + aliasing_info = subtracer.has_aliasing() + if aliasing_info.has_aliasing: + context = f"{aliasing_info.msg} in\n {graph}" + unimplemented( + gb_type="Encountered aliasing during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support aliasing. Found in {source_target.name}", + hints=[ + "Replace `return input` with `return input.clone()` to avoid aliasing.", + "Consider using the debug context to change user code to avoid aliasing.", + "Please open an issue.", + ], + ) + + +def trace_hop_function( + f, + tx, + subtracer, + enable_grad, + restore_side_effects, + args, + sub_kwargs, +): + # For autograd.Function and other legacy HOPs, we do NOT couple + # restore_side_effects with allow_side_effects_in_hop. + # This preserves the old behavior where: + # - restore_side_effects=False means ctx mutations persist + # - But non-ctx side effects still cause graph breaks (under_activation_checkpoint was False) + enable_side_effects_with_extra_outputs = False + + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if enable_side_effects_with_extra_outputs + else contextlib.nullcontext() + ) + + # For handling side effects, we can make an argument that we don't + # have to do anything here. The side effects infra does a good job + # of graph breaking if we mutate any nonlocal or global variable + # while subtracing. As a result if tracing succeeds, side effects + # data structure will only contain read-only data structures that + # are put there for tracking purposes. + # But on the other hand, there is an argument that if we ever write + # a new side effect in Dynamo which does not go through the side + # effect infra, we can end up in bad state. + # Therefore we restore the side effects after tracing. The catch is + # that we have to special handle tensor variables. If we have seen a + # nonlocal variable tensor during subtracing, we want to keep a + # track of that tensor, so that later subtracing or the root tracer + # itself does not create a new proxy for the already observed tensor + # variable. + if restore_side_effects: + prev_side_effects = tx.output.side_effects.clone() + + with autograd_ctx, side_effects_ctx: + output = f.call_function(tx, args, sub_kwargs) + + if restore_side_effects: + new_side_effects = tx.output.side_effects.clone() + prev_side_effects.track_runahead_tensor_and_symvar_side_effects( + new_side_effects + ) + tx.output.side_effects = prev_side_effects + return output + + +def trace_hop_function_with_auto_output_flattening( + f, + tx, + subtracer, + enable_grad, + allow_side_effects, + args, + sub_kwargs, +): + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if allow_side_effects + else contextlib.nullcontext() + ) + + with autograd_ctx, side_effects_ctx: + output = f.call_function(tx, args, sub_kwargs) + + return output + + +def get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description +): + sub_args_names = maybe_positional_arg_names(f) + # User mismatch in the number of args. Will eventually lead to an error. + if sub_args_names is not None and len(sub_args_names) < len(sub_args): + sub_args_names = None + args = validate_args_and_maybe_create_graph_inputs( + sub_args, + subtracer, + tx, + set_subgraph_inputs, + description, + sub_args_names, + ) + + validate_args_and_maybe_create_graph_inputs( + sub_kwargs.values(), + subtracer, + tx, + set_subgraph_inputs="automatic", + description=description, + ) + return args + + +# TODO - The eventual goal is to replace +# speculate_subgraph_with_auto_output_flattening with speculate_subgraph or +# merge them two into one. We are following a staged approach because of +# existing implementation complexity for control flow ops. +def speculate_subgraph_with_auto_output_flattening( + tx: "InstructionTranslator", + f: VariableTracker, + sub_args: Sequence[VariableTracker], + sub_kwargs: Optional[dict[str, VariableTracker]], + description: str, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target: Optional[HigherOrderOperator] = None, + enable_grad: Optional[bool] = None, + # TODO - We can probably just make everyone use automatic for wrap_semantics + set_subgraph_inputs: Literal[ + "automatic", "semi_automatic", "flatten_manual", "manual" + ] = "automatic", + # If True, exposes intermediates to subgraph outputs to allow later tensor ops to + # access intermediates from the subgraph, this is useful for mutation + allow_side_effects: bool = False, + # Controls whether to filter aliased intermediates when collecting extra outputs. + # This is only relevant when allow_side_effects=True. + # - True: Filter out intermediates that alias with inputs or outputs (strict, for invoke_subgraph) + # - False: Allow aliased intermediates (for checkpoint/autograd.Function which get desugared/inlined) + # + # Example where filtering is needed: + # + # @invoke_subgraph + # def gn(x): + # view = x.view(2, 4) # intermediate that aliases input x + # y = torch.sin(view) + # return torch.cos(view) + # + # def fn(x): + # res = gn(x) + # return res + 4 + # + # In this case, if we don't filter `view`, we would later error because some HOPs + # have strict aliasing checks on inputs/outputs. + # + # This does however introduce a subtle issue when we do something like: + # + # captured = [] + # + # @invoke_subgraph + # def gn(x): + # view = x.view(2, 4) # intermediate that aliases input x + # y = torch.sin(view) + # captured.append(view) + # return torch.cos(view) + # + # def fn(x): + # res = gn(x) + # return res + captured[0] + # + # In this case, we will not replay the side effect on `captured` in the graph, + # which fails with a not-so-nice error. We will address this in a follow-up PR + # because this case is rare. This is not a regression because side effects were + # never supported for invoke_subgraph anyway. + filter_aliased_intermediates: bool = False, + # TODO - supports input_mutation and aliasing should be False by default for strictness + supports_input_mutation: bool = True, + supports_aliasing: bool = True, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer: Optional["torch._dynamo.output_graph.SubgraphTracer"] = None, +) -> tuple[ + VariableTracker, # output: The VT that Dynamo continues tracing with + torch.fx.Graph, # graph: The FX graph representing the subgraph computation + dict[ + torch.fx.Proxy, torch.fx.Proxy + ], # lifted_freevars: Free variables lifted as inputs + VariableTracker + | tuple[ + VariableTracker, ... + ], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs +]: + """ + Speculate subgraph for Higher-Order Operators (HOPs) with automatic output flattening. + + ## Automatic output flattening + + For many HOPs, the representation exists only as a container for the + subgraph. In later compiler stages or at runtime, the HOP is desugared and + simply executes the subgraph directly, as if it were inlined. For such hops, + we follow automatic output flattening. + For example: + - invoke_subgraph + - activation checkpointing (torch.utils.checkpoint.checkpoint) + - autograd.Function + - nested_compile_region + + This is in contrast to control flow HOPs which do not follow this desugaring: + - torch.cond (conditional execution based on predicate) + - torch.while_loop (iterative execution) + - torch.map (parallel execution over batch dimension) + + For control flow HOPs, the HOP behavior is fundamentally different from just + running the body function once. + + ## Key Advantage: Disentangling VTs from Graph Outputs + + Desugaring simplify HOP processing by allowing us to disentangle the output + variable trackers (VTs) from the HOP subgraph outputs. This mirrors typical + Dynamo processing where: + - VTs "run ahead" representing the program state for continued tracing + - The graph is a side data structure tracking computation seen so far + + This separation is crucial for HOPs with non-proxyable outputs (e.g., custom + user-defined objects containing tensors). The function may return complex Python + objects for Dynamo to continue tracing, but only the tensor/symint VTs need to + be registered as actual FX graph outputs. + + Example: + class Foo: + def __init__(self, a, b): + self.a = a # tensor + self.b = b # tensor + + def gn(x): + return Foo(torch.sin(x), torch.cos(x)) + + result = some_hop(gn, x) # Returns Foo instance + out = result.a + result.b # Dynamo can continue tracing + + Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but + `graph_output_vts` contains only the tensor VTs (a and b) that should be + actual FX graph outputs. This allows Dynamo to continue tracing with the + Foo object while the graph only needs to output the constituent tensors. + + ## Return Values + + Unlike `speculate_subgraph`, this function returns: + - output: The VT that Dynamo continues tracing with (may be complex Python objects) + - graph: The FX graph representing the subgraph computation + - lifted_freevars: Free variables lifted as inputs to the subgraph + - graph_output_vts: Only the tensor/symint VTs that are actual FX graph outputs + + The key difference is `graph_output_vts` instead of `treespec`, which gives more + flexibility for handling non-proxyable outputs. + """ + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "semi_automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented( + gb_type="invalid set_subgraph_inputs and sub_kwargs settings", + context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + hints=[ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + *graph_break_hints.USER_ERROR, + ], + ) + + try: + # ensure guards on args get installed in parent subgraph + f, sub_args, sub_kwargs = LazyVariableTracker.realize_all( + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer, description) as subtracer: + args = get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description + ) + + # Special case - if users uses + # `traced_with_externally_visible_side_effects`, we still need to + # return the intermediates as outputs. However, this API gets + # triggered during the hop tracing, and we don't know at this point + # of time, if the API will take into effect. To handle this, we have + # a flag traced_with_externally_visible_side_effects (default=False) + # that is set to True anytime + # `traced_with_externally_visible_side_effects` is set. We reset it + # with the old value after the hop is traced out. + old_value = ( + tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + + output = trace_hop_function_with_auto_output_flattening( + f, + tx, + subtracer, + enable_grad, + allow_side_effects, + args, + sub_kwargs, + ) + + # NOTE: [Separation of graph outputs and output VTs] + # In Dynamo (outside of speculate_subgraph), VTs and the graph are + # separate concepts: + # - VTs (VariableTrackers) can "run ahead" and continue Dynamo tracing + # - The graph is just a side data structure tracking computation seen so far + # + # This separation is crucial for HOPs with non-proxyable outputs (e.g., + # custom user-defined objects containing tensors). The function may return + # complex Python objects for Dynamo to continue tracing, but only the + # tensor/symint VTs need to be registered as actual graph outputs. + # + # Example: + # class Foo: + # def __init__(self, a, b): + # self.a = a # tensor + # self.b = b # tensor + # + # def gn(x): + # return Foo(torch.sin(x), torch.cos(x)) + # + # Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but + # `graph_output_vts` contains only the tensor VTs (a and b) that should + # be actual FX graph outputs. + # Collect only tensor and symint VTs that should be graph outputs. + # We walk the output structure and extract proxyable VTs. + graph_output_vts = [] + + def visit(vt): + if vt.is_tensor() or isinstance(vt, SymNodeVariable): + graph_output_vts.append(vt) + + VariableTracker.visit(visit, output) + graph_output_vts = tuple(graph_output_vts) + + # NOTE - [Return subgraph intermediates as subgraph outputs] + # This helps HOPs which allow side effects. Consider the + # following example + # + # def gn(x, z): + # o = torch.matmul(x, x) @ x + # out = x.sin() + # z.append(out) + # return torch.cos(torch.sin(o)) + + # def fn(x): + # z = [] + # out1 = torch.utils.checkpoint.checkpoint( + # gn, + # x, + # z, + # use_reentrant=False, + # ) + # return out1, z[0] + # + # In this example, list `z` is in outer scope and gets appended + # in the subgraph with `out`. But `out` is not an output of the + # subgraph. This can cause issue because later on when the outer + # graph returns `z[0]` it needs to have access to the graph node + # `out`. To solve this problem, we just return all intermediates + # from the subgraph. + + # TODO - Today this is supported only for AC. AC HOP gets + # desugared in AOTDispatcher so even though subgraph has extra + # unused outputs in Dynamo, its ok even if we don't DCE them in + # Dynamo. As AOTDispatcher desugars/inlines the subgraph, the + # subgraph boundary disappears. And even for AC, today this only + # works when the skip_fwd_side_effects_in_bwd_under_checkpoint + # flag is True, i.e., only when we allow side-effects. But, we + # want this to be supported for other Hops as well, specifically + # nested_compile_region and autograd.Function. Today, its safe + # because we error out on seeing a side-effect. + + allow_side_effects = ( + allow_side_effects + or tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + if allow_side_effects: + extra_outputs = collect_intermediate_outputs( + tx, subtracer, graph_output_vts, filter_aliased_intermediates + ) + graph_output_vts = graph_output_vts + tuple(extra_outputs) + + tx.output.current_tracer.traced_with_externally_visible_side_effects = ( + old_value + ) + + validate_subgraph_output_types(graph_output_vts) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + # output_proxies = output.as_proxy() + if isinstance(graph_output_vts, tuple): + output_proxies = [a.as_proxy() for a in graph_output_vts] + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + output_proxies = tuple(output_proxies) + else: + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + if len(lifted_freevars) > 0: + move_lifted_freevars_phs_to_end(graph, lifted_freevars) + + check_aliasing_and_input_mutation( + subtracer, + graph, + supports_input_mutation, + supports_aliasing, + source_target, + ) + # Return both the output VT and the graph output VTs separately: + # - `output`: The VT that Dynamo continues tracing with (may be + # complex Python objects, tuples, dicts, etc.) + # - `graph`: The FX graph representing the subgraph computation + # - `lifted_freevars`: Free variables lifted as inputs to the subgraph + # - `graph_output_vts`: Only the tensor/symint VTs that are actual + # FX graph outputs (basically the vts associated with graph outputs) + return ( + output, + graph, + lifted_freevars, + graph_output_vts, + ) + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." + ) + log.info(msg) + log.info(ex) # noqa: G200 + raise ex + + +# See NOTE [HigherOrderOperator tracing design] for details of the design +def speculate_subgraph( + tx, + f, + sub_args, + sub_kwargs, + description, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target=None, + always_restore=False, + enable_grad=None, + # NOTE [argument `set_subgraph_inputs`] + # set_subgraph_inputs controls what how to construct subgraphs' placeholders from sub_args. + # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended). + # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended). + # If sub_args contain Pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first. + # Then the flattened args are manually set as subgraph's placeholders. + # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders e.g. AutogradFunctionContextVariable + # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the + # restriction that user need to manually control how to create placeholders and VariableTrackers for the args. + set_subgraph_inputs="automatic", + restore_side_effects=True, + should_flatten_outputs=False, + # if should_flatten_outputs is True, `remove_consts_from_outputs` remove the + # const outputs from the subgraph output. + remove_consts_from_outputs=True, + # TODO - supports input_mutation and aliasing should be False by default for strictness + supports_input_mutation=True, + supports_aliasing=True, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer=None, +): + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "semi_automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented( + gb_type="invalid set_subgraph_inputs and sub_kwargs settings", + context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + hints=[ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + *graph_break_hints.USER_ERROR, + ], + ) + + try: + # ensure guards on args get installed in parent subgraph + f, sub_args, sub_kwargs = LazyVariableTracker.realize_all( + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer, description) as subtracer: + args = get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description + ) + + output = trace_hop_function( + f, + tx, + subtracer, + enable_grad, + restore_side_effects, + args, + sub_kwargs, + ) + + treespec = None + masks_to_filter_const_values = None + const_values = None + if should_flatten_outputs: + from torch._dynamo.external_utils import filter_out_const_values + + # Flatten the speculated subgraph output. + output, treespec = _make_inlined(tx, pytree.tree_flatten)( + output + ).unpack_var_sequence(tx) + + # Actually, transform the list (returned by flatten) into a tuple + # for dynamo consistency. + output = BuiltinVariable(tuple).call_function(tx, [output], {}) + + if remove_consts_from_outputs: + # Filter out the constants and save them into a spec. Filtering + # out constants makes the graph simpler for the backends. We + # need to ensure that after unflattening the constants are + # inserted back at the right positions for the Dynamo tracing to + # continue. This is done by filter_const_spec + output_proxies = output.as_proxy() + masks_to_filter_const_values = pytree.tree_map( + lambda x: not isinstance(x, torch.fx.Proxy), output_proxies + ) + const_values = pytree.tree_map( + lambda x: None if isinstance(x, torch.fx.Proxy) else x, + output_proxies, + ) + output = _make_inlined(tx, filter_out_const_values)( + output, masks_to_filter_const_values + ) + + # TODO - clean up num_intermediate_nodes_as_outputs - we do not need + # after AC moved to auto_output_flattening + num_intermediate_nodes_as_outputs = 0 + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support pytree output + # We check always_restore because we dont use the output or side effects of always_restore code, + # like bwd. + if always_restore: + # Nothing left to do here + return ( + ( + output, + OutputSpec( + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, + ), + ), + tx.output.graph, + subtracer.lifted_freevars, + ) + else: + validate_subgraph_output_types(output) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + if len(lifted_freevars) > 0: + move_lifted_freevars_phs_to_end(graph, lifted_freevars) + + check_aliasing_and_input_mutation( + subtracer, + graph, + supports_input_mutation, + supports_aliasing, + source_target, + ) + + return ( + ( + output, + OutputSpec( + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, + ), + ), + graph, + lifted_freevars, + ) + + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." + ) + log.info(msg) + log.info(ex) # noqa: G200 + raise ex + + +def make_attr(tx: "InstructionTranslator", name): + node = tx.output.create_proxy( + "get_attr", + name, + (), + {}, + ) + return node + + +class TorchHigherOrderOperatorVariable(VariableTracker): + def __init__( + self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.value = value + self.source = source + + @staticmethod + def make(value, source=None, **kwargs): + variable_class = _hop_name_to_variable_class.get(value.__name__) + if variable_class is not None: + return variable_class(value, source, **kwargs) + + from torch._higher_order_ops import BaseHOP + + if isinstance(value, BaseHOP): + return BaseHOPVariable(value, source, **kwargs) + unimplemented( + gb_type="unsupported HigherOrderOperator", + context=str(value), + explanation=f"Unable to create higher order operator variable for {value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + + return self._call_function(tx, args, kwargs) + + def _call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + unimplemented( + gb_type="unsupported HigherOrderOperator function call", + context=str(self.value), + explanation=f"Unable to trace calling higher order operator variable for {self.value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + def as_python_constant(self): + return self.value + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Wraps torch._functorch.autograd_function.custom_function_call + """ + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return torch._dynamo.variables.UserMethodVariable( + self.value.__call__.__func__, + torch._dynamo.variables.UserDefinedObjectVariable( + self.value, source=self.source + ), + source=AttrSource(self.source, "__call__"), + ).call_function(tx, args, kwargs) + + +class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="Cond doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ListVariable + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + # TODO(voz): Support fake tensor dispatch for recursive + # ops - see torch/dispatch/_dispatcher.py + if len(args) != 4 or kwargs: + unimplemented( + gb_type="torch.cond: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: cond(pred, cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Specialize into one of the branches since pred is constant + pred, true_fn, false_fn, operands = args + if type(args[0]) is ConstantVariable: + warnings.warn( + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." + " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool.", + UserWarning, + ) + if pred.as_python_constant(): + return true_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + else: + return false_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + + # predicate + if type(pred.realize()) not in ( + ConstantVariable, + TensorVariable, + SymNodeVariable, + ): + unimplemented( + gb_type="torch.cond: improper predicate", + context=str(pred), + explanation="Expected `pred` to be a bool or a boolean tensor with a single item " + f"but got {str(type(pred))} with original python type {str(pred.python_type())}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # operands + if not isinstance(operands, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.cond: improper operands", + context=str(operands), + explanation="Expected `operands` to be a list/tuple " + f"but got {operands.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + operands_seq = operands.unpack_var_sequence(tx) + if not only_consist_of( + operands, (TensorVariable, ConstantVariable, SymNodeVariable) + ): + unimplemented( + gb_type="torch.cond: improper operands contents", + context=str(operands), + explanation="Expected `operands` to be a list/tuple of pytrees that only consists of tensor leaves.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # branches + _check_supported_callable_arg(tx, true_fn, "true_fn") + _check_supported_callable_arg(tx, false_fn, "false_fn") + + # Our strategy for tracing the true/false branches of cond + # are to checkpoint our graphstate, run the true branch, + # roll it back to the checkpoint, and run the false + # branch, and then merge the graphstates. Well, perhaps + # "merge" is too strong a word: we mostly assert that + # the resulting graphstates have to be the same. + # + # We only permit guards to diverge (we union the guards from + # both branches). In particular, this means that side + # effects are NOT permitted inside true/false branches; this + # would be difficult to implement, because of the path + # explosion problem. + + def speculate_branch(branch): + # NB: 0 is predicate + ix = 1 if branch else 2 + # TODO: Support kwargs + ( + (ret_val, ret_spec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[ix], + operands_seq, + {}, + "cond", + source_target=self.value, + should_flatten_outputs=True, + # TODO - removing consts from control flow ops need more work + remove_consts_from_outputs=False, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # need to ensure we increase epoch so we don't memoize unbacked bindings + # across different subgraphs which can interfere with runtime assertion + # generation. + tx.fake_mode.epoch += 1 + + if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)): + unimplemented( + gb_type="torch.cond: unsupported branch return type", + context=str(ret_val), + explanation="Expected branches to return a possibly nested pytree of tensors or constant ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + for ret in ret_val.unpack_var_sequence(tx): + if ret.is_python_constant() and not isinstance( + ret.as_python_constant(), int + ): + unimplemented( + gb_type="torch.cond: unsupported branch return type (constant non-int)", + context=str(ret_val), + explanation="Constants returned from branches must be ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + return ret_val, ret_spec, ret_graph, ret_lifted_freevars + + (true_r, true_spec, true_graph, true_lifted_freevars) = speculate_branch(True) + true_nn_modules = dict(tx.output.nn_modules) + + ( + false_r, + false_spec, + false_graph, + false_lifted_freevars, + ) = speculate_branch(False) + false_nn_modules = dict(tx.output.nn_modules) + + same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)( + true_spec.treespec, false_spec.treespec + ).as_python_constant() + # 3.14: NotImplemented cannot be converted to bool + if same_spec is not NotImplemented and not same_spec: + unimplemented( + gb_type="torch.cond: differing branch outputs", + context=f"true_spec: {true_spec.treespec}, false_spec: {false_spec.treespec}, same_spec: {same_spec}", + explanation="Expected branches to return the same pytree structure.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + ( + true_graph, + false_graph, + true_shared, + _false_shared, + unique_true, + unique_false, + ) = _merge_graph_inputs( + true_graph, + true_lifted_freevars, + "true_branch", + false_graph, + false_lifted_freevars, + "false_branch", + ) + + true_name = tx.output.install_subgraph( + "cond_true", + torch.fx.GraphModule(true_nn_modules, true_graph), + ) + false_name = tx.output.install_subgraph( + "cond_false", + torch.fx.GraphModule(false_nn_modules, false_graph), + ) + + true_node = make_attr(tx, true_name) + false_node = make_attr(tx, false_name) + + p_args = ( + pred.as_proxy(), + true_node, + false_node, + # We pick true_shared but it shouldn't matter + tuple(true_shared + unique_true + unique_false), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.cond, + p_args, + {}, + None, + true_spec, + true_r, + ) + + +class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable): + def __init__(self, hop, source, script_obj_var, method_name) -> None: + super().__init__(hop, source) + self.script_obj_var = script_obj_var + self.method_name = method_name + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + args_proxy = [arg.as_proxy() for arg in args] + kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple( + [self.script_obj_var.as_proxy(), self.method_name] + args_proxy + ), + kwargs=kwargs_proxy, + ), + ) + + +def validate_subgraph_output_types(output: VariableTracker): + """Verify that that the output of the subgraph is a tensor, + int, bool, SymBool, or SymInt. + """ + from . import TensorVariable + + if non_tensor_output := find_mismatched_vars( + output, TensorVariable, allow_none=True + ): + for out in non_tensor_output: + if ( + isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) + ) or ( + out.is_python_constant() + and isinstance(out.as_python_constant(), (int, bool)) + ): + continue + unimplemented( + gb_type="HOP body output unsupported", + context=f"non-tensor outputs: {non_tensor_output}", + explanation="HigherOrderOperator body's output must consist of tensors or ints/bools only " + f"but got {out.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + +class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="while_loop doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return _call_while_loop(self, tx, args, kwargs, stack_output=False) + + +class WhileLoopStackOutputHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="while_loop_stack_output doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return _call_while_loop(self, tx, args, kwargs, stack_output=True) + + +class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="associative_scan must be captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + def arg_extractor(combine_fn, xs, additional_inputs): + return combine_fn, xs, additional_inputs + + combine_fn, xs, additional_inputs = arg_extractor(*args, **kwargs) + + if args[0].python_type() is functools.partial: + # This is the standard case when the user calls the frontend + # and the frontend invokes dynamo + if len(args) != 2: + unimplemented( + gb_type="torch.associative_scan: improper args", + context=f"args: {args}", + explanation=f"torch.associative_scan expects 2 positional arguments (got {len(args)}) " + "Usage: associative_scan(combine_fn, xs)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + xs_treespec = args[0].keywords["spec"] + + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + else: + # This case is hit during re-tracing, for example in export tests + # In this case, the combine_fn is a callable and not a functools.partial + xs_treespec = _make_inlined(tx, pytree.tree_structure)(xs) + + _check_supported_callable_arg(tx, combine_fn, "combine_fn") + + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.associative_scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + xs_vars = xs.unpack_var_sequence(tx) + _check_all_tensorvariable(xs_vars) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.associative_scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + _check_all_tensorvariable(additional_inputs_vars) + + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented( + gb_type="torch.associative_scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Trace the subgraph + # The sub_args is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args shape will be (4, ). + with discard_graph_changes(tx): + sub_args = [ + _make_inlined(tx, first_slice_copy)(leaf) + for leaf in itertools.chain(xs_vars, xs_vars) + ] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args + sub_args_additional_inputs + ( + (combine_result, _combine_spec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="associative_scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + + # Collect the results from the combine_fn + results, _combine_treespec = _make_inlined(tx, pytree.tree_flatten)( + combine_result + ).unpack_var_sequence(tx) + + # Check whether the combine_fn returns one child tree for the output. + if _combine_treespec.as_python_constant().num_leaves < 1: + unimplemented( + gb_type="torch.associative_scan: combine_fn improper number of leaves", + context=str(_combine_treespec.as_python_constant()), + explanation="combine_fn needs to produce one pytree for the output " + f"but combine_fn produces the pytree {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Check whether the outs produced by combine_fn has the same treespec as xs + # We need to have this check this way, because in case init is a TreeSpec and carry + # but carry is only a LeafSpec, these two cannot be compared correctly. + if ( + xs_treespec.as_python_constant().is_leaf() + != _combine_treespec.as_python_constant().is_leaf() + ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( + xs_treespec, _combine_treespec + ).as_python_constant(): + unimplemented( + gb_type="torch.associative_scan: mismatched input/output tree structure", + context=f"xs: {xs_treespec.as_python_constant()}, output: {_combine_treespec.as_python_constant()}", + explanation="The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " + f"xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + [_make_inlined(tx, first_slice_copy)(t) for t in xs_vars], + results.items, + "initial_xs", + "combine_fn_output", + include_contiguity=False, + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + + # Compute the proxies for the input check + proxy_vars_inputcheck = ( + tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy + ) + + from torch._higher_order_ops.utils import _maybe_fake_tracing + from torch._inductor.utils import is_pointwise_use + + with tx.fake_mode: + sub_args_fake = [ + ( + leaf.node.meta["example_value"].clone() + if hasattr(leaf.node.meta["example_value"], "clone") + else leaf.node.meta["example_value"] + ) + for leaf in pytree.tree_leaves(proxy_vars_inputcheck) + ] + pre_dispatch = False + + fx = _maybe_fake_tracing( + combine_gm, sub_args_fake, pre_dispatch=pre_dispatch + ) + + for node in fx.graph.nodes: + # Check that the combine_fn is pointwise, if combine_mode='pointwise' + if not all( + is_pointwise_use(use) or use.op == "output" for use in node.users + ): + raise RuntimeError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + + combine_fn_name = tx.output.install_subgraph( + "associative_scan_combine_fn", combine_gm + ) + + # Compute the proxies + xs_proxy = xs.as_proxy() + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + additional_inputs_proxy = additional_inputs.as_proxy() + combine_freevars_proxy + + p_args = ( + make_attr(tx, combine_fn_name), + xs_proxy, + additional_inputs_proxy, + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.associative_scan, + p_args, + {}, + None, + OutputSpec(xs_treespec), + None, + ) + + +class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="scan must be captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.scan import _extract_carry_and_out + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + # combine_fn input check + def _check_combine_fn_is_normalized(combine_fn_var): + if not isinstance( + combine_fn_var, + ( + variables.nn_module.NNModuleVariable, + variables.nn_module.UnspecializedNNModuleVariable, + variables.FunctoolsPartialVariable, + ), + ): + unimplemented( + gb_type="torch.scan: improper combine_fn", + context=str(combine_fn_var), + explanation="Expected combine_fn to be wrapped as functools.partial in scan user-facing api " + f"or a graph module if we're re-exporting but got {combine_fn_var.python_type()}.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + return isinstance( + combine_fn_var, + ( + variables.nn_module.NNModuleVariable, + variables.nn_module.UnspecializedNNModuleVariable, + ), + ) + + def arg_extractor(combine_fn, init, xs, additional_inputs): + return combine_fn, init, xs, additional_inputs + + combine_fn, init, xs, additional_inputs = arg_extractor(*args, **kwargs) + init_vars = init.unpack_var_sequence(tx) + xs_vars = xs.unpack_var_sequence(tx) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + + # combine_fn input check + combine_fn_is_normalized = _check_combine_fn_is_normalized(combine_fn) + if combine_fn_is_normalized: + combine_gm = combine_fn.value + assert isinstance(combine_gm, torch.fx.GraphModule), ( + combine_fn, + combine_gm, + ) + else: + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + # init input check + if not isinstance(init, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper init", + context=str(init), + explanation=f"Expected init to be a list/tuple with at least one element but got {init.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + if len(init_vars) == 0: + unimplemented( + gb_type="torch.scan: no init leaves", + context="", + explanation="Expected init leaves.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + # scan_length check + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented( + gb_type="torch.scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) + _check_all_tensorvariable(init_vars) + _check_all_tensorvariable(xs_vars) + _check_all_tensorvariable(additional_inputs_vars) + + with discard_graph_changes(tx): + sub_args_init = [ + ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init_vars + ] + # The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args_inp shape will be (4, ). + sub_args_inp = [_make_inlined(tx, first_slice_copy)(inp) for inp in xs_vars] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs + ( + (combine_result, _combine_spec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + combine_freevars_proxy = list(combine_lifted_freevars.keys()) + combine_result_vars = combine_result.unpack_var_sequence(tx) + + if combine_fn_is_normalized: + carry_vars, out_vars = _extract_carry_and_out( + combine_result_vars, len(init_vars) + ) + else: + if len(combine_result_vars) != 2: + unimplemented( + gb_type="torch.scan: improper combine_fn number of returns", + context=str(combine_result_vars), + explanation=f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + carry_tree, out_vars = combine_result_vars + carry_vars, _ = _make_inlined(tx, pytree.tree_flatten)( + carry_tree + ).unpack_var_sequence(tx) + carry_vars = carry_vars.unpack_var_sequence(tx) + out_vars = _make_inlined(tx, pytree.tree_leaves)( + out_vars + ).unpack_var_sequence(tx) + + # additional output checking + _combine_spec = OutputSpec( + _make_inlined(tx, pytree.tree_structure)(combine_result) + ) + + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + ) + + # Check meta data of carries and inits. If we pass this stage, we are sure that the init and carries + # have the same tree structure. + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + include_contiguity=False, + ) + + xs_proxy = xs.as_proxy() + init_proxy = init.as_proxy() + additional_inputs_proxy = list(additional_inputs.as_proxy()) + list( + combine_freevars_proxy + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm) + + p_args = ( + make_attr(tx, combine_fn_name), + init_proxy, + xs_proxy, + additional_inputs_proxy, + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.scan, + p_args, + {}, + None, + _combine_spec, + None, + ) + + +def non_single_tensor_return_unsupported(api, ret): + if not ret.is_tensor(): + unimplemented( + gb_type="non-single Tensor return unsupported", + context=f"api: {api}, ret: {ret}", + explanation=f"{api} over function that returns something other than one Tensor.", + hints=[], + ) + + +class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="map doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if len(kwargs) > 0: + unimplemented( + gb_type="torch.map: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.map expects no keyword arguments (got {len(kwargs)})", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + _check_supported_callable_arg(tx, args[0], "map_fn") + + # args = f, flat_xs, flat_args + assert isinstance(args[1], (ListVariable, TupleVariable)), args[1] + assert isinstance(args[2], (ListVariable, TupleVariable)), args[2] + unpacked_xs = args[1].unpack_var_sequence(tx) + unpacked_args = args[2].unpack_var_sequence(tx) + + sample_shape = get_fake_value(unpacked_xs[0].as_proxy().node, tx).size() + + if len(sample_shape) < 1 or sample_shape[0] == 0: + unimplemented( + gb_type="torch.map: improper inputs", + context=str(sample_shape), + explanation="torch.map doesn't support scalar or non-zero sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # To get the example output from map() we will need to provide at least one sample to + # the loop body. In our case we will always use xs[0], and our map() won't support zero + # sized tensor during tracing. + with discard_graph_changes(tx): + sliced_xs = [ + xs.call_method( + tx, + "select", + args=(VariableTracker.build(tx, 0), VariableTracker.build(tx, 0)), + kwargs={}, + ) + for xs in unpacked_xs + ] + + # TODO: Support kwargs + ( + (body_r, body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + [ + *sliced_xs, + *unpacked_args, + ], + {}, + "torch.ops.higher_order.map", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + # TODO - removing consts from control flow ops need more work + remove_consts_from_outputs=False, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Check all outputs of map are tensors. + # For map, outputting None is OK, thus ignore None values in the check + body_r_vars = body_r.unpack_var_sequence(tx) + none_mask = [x.is_constant_none() for x in body_r_vars] + _check_all_tensorvariable( + [br for bm, br in zip(none_mask, body_r_vars) if not bm] + ) + + body_nn_modules = dict(tx.output.nn_modules) + + body_name = tx.output.install_subgraph( + "map_body", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + p_args = ( + body_node, + [xs.as_proxy() for xs in unpacked_xs], + [arg.as_proxy() for arg in unpacked_args] + + list(body_lifted_freevars.keys()), + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec, body_r + ) + + +class PrintHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + args_proxy = [arg.as_proxy() for arg in args] + kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(args_proxy), + kwargs=kwargs_proxy, + ), + ) + + +class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + # This is operator for delegation within Executorch which calls a + # specific function in the given lowered module with the given + # operators. The actual operator is defined in the Executorch codebase. + # This is a bad hierarchical violation since + # executorch_call_delegate sits at a higher level than dynamo, but + # there's no real solution to this issue yet. + if len(kwargs) > 0: + unimplemented( + gb_type="executorch_call_delegate: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"executorch_call_delegate expects no keyword arguments (got {len(kwargs)})", + hints=[], + ) + if isinstance(args[0], variables.NNModuleVariable): + lowered_module = tx.output.get_submodule(args[0].module_key) + lowered_node = make_attr(tx, args[0].module_key) + elif isinstance(args[0], variables.UnspecializedNNModuleVariable): + # This nn module is special sa delegated by executorch. Just + # install it as a attr in the graph. + lowered_module = args[0].value + lowered_node = tx.output.register_static_attr_and_return_proxy( + "delegate", lowered_module + ) + + p_args = tuple(arg.as_proxy() for arg in args[1:]) + real_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args + ) + + with tx.fake_mode: + example_value = lowered_module.original_module.module()(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. + _assert_tensors_nonaliasing(real_sub_args, example_value) + + p_args = (lowered_node,) + p_args + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class FunctorchHigherOrderVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return super().call_function(tx, args, kwargs) + + def should_allow_nested_graph_breaks(self): + return False + + +class FunctionalCallVariable(FunctorchHigherOrderVariable): + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + if not torch._dynamo.config.inline_inbuilt_nn_modules: + unimplemented( + gb_type="torch.func.functional_call capture is disabled", + context="", + explanation="torch.func.functional_call capture is disabled", + hints=[ + "Set `torch._dynamo.config.inline_inbuilt_nn_modules=True` to enable.", + ], + ) + return super().call_function(tx, args, kwargs) + + +class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + ctx_manager_vt = super().call_function(tx, args, kwargs) + return RepararametrizeModuleContextVariable(ctx_manager_vt, args[0]) + + +class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = True + supports_aliasing = True + allow_side_effects = False + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return tx.output.install_subgraph( + f"{attr_name}", + body_gmod, + ) + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + fn_vt, + fn_args_vt, + kwargs, + description, + *, + subgraph_name="wrap_body", + ): + # See NOTE [HigherOrderOperator tracing design] for more details + ( + body_r, + body_graph, + body_lifted_freevars, + body_graph_output_vts, + ) = speculate_subgraph_with_auto_output_flattening( + tx, + fn_vt, + fn_args_vt, + kwargs, + description, + source_target=self.value, + allow_side_effects=self.allow_side_effects, + filter_aliased_intermediates=getattr( + self, "filter_aliased_intermediates", False + ), + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = self.install_subgraph_in_output_graph( + tx, + fn_vt, + fn_args_vt, + kwargs, + body_gmod, + attr_name=subgraph_name, + ) + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. + lifted_args = tuple(arg for arg in body_lifted_freevars) + + proxy_args = (body_node,) + lifted_args + + example_value = pytree.tree_map_only( + torch.fx.Node, + lambda a: a.meta["example_value"], + body_graph.find_nodes(op="output")[0].args[0], + ) + + return ( + proxy_args, + {}, + example_value, + body_r, + body_gmod, + body_name, + body_graph_output_vts, + ) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + _example_value, + body_r, + _, + _, + body_graph_output_vts, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") + + if len(p_kwargs) > 0: + unimplemented( + gb_type="WrapHigherOrderVariable: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + return _call_function_with_auto_output_flattening( + tx, + self.value, + tuple(p_args), + p_kwargs, + _example_value, + body_r, + body_graph_output_vts, + ) + + +class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + gb_type="wrap_with_set_grad_enabled: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_set_grad_enabled expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + grad_enabled, fn_var, *rest_args = args + + if not grad_enabled.is_python_constant(): + unimplemented( + gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", + context=str(grad_enabled), + explanation="wrap_with_set_grad_enabled expects grad_enabled argument to be a constant.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + _check_supported_callable_arg(tx, fn_var, "enable_grad_fn") + + with torch.set_grad_enabled(grad_enabled.as_python_constant()): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_set_grad_enabled", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + gb_type="wrap_with_set_grad_enabled: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_set_grad_enabled expects no freevars.", + hints=[], + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + grad_enabled.as_python_constant(), + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec, body_r + ) + + +class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + gb_type="wrap_with_autocast: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_autocast expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args + + for arg in [device_type, dtype, enabled, cache_enabled]: + if not arg.is_python_constant(): + unimplemented( + gb_type="wrap_with_autocast: expected constant arg", + context=str(args), + explanation="wrap_with_autocast expects device_type, dtype, enabled, " + "and cache_enabled arguments to be constants.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + _check_supported_callable_arg(tx, fn_var, "autocast") + + python_constants = [ + arg.as_python_constant() + for arg in [device_type, dtype, enabled, cache_enabled] + ] + + with torch.autocast(*python_constants): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_autocast", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + gb_type="wrap_with_autocast: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_autocast expects no freevars.", + hints=[], + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + *python_constants, + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec, body_r + ) + + +class HintsWrapperHigherOrderVariable(WrapHigherOrderVariable): + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return tx.output.install_subgraph( + "hints_wrapper_body", + body_gmod, + ) + + @raise_hard_error_if_graph_break( + reason="hints_wrapper doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + _check_supported_callable_arg(tx, args[0], "body_fn") + + # inputs + if ( + len(args) != 3 + or not isinstance(args[1], (ListVariable, TupleVariable)) + or not isinstance(args[2], ConstDictVariable) + or len(kwargs) != 1 + or "hints" not in kwargs + ): + unimplemented( + gb_type="hints_wrapper: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"hints_wrapper expects 3 positional arguments (got {len(args)}) " + f"and 1 keyword argument (got {len(kwargs)}). " + "Usage: hints_wrapper(body_fn, args, kwargs, hints=...). " + "args is expected to be list/tuple and kwargs is expected to be a dict.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + operands = args[1].unpack_var_sequence(tx) + fn_kwargs = args[2].as_python_constant() + + # Use create_wrapped_node from WrapHigherOrderVariable + ( + p_args, + _, + example_value, + body_r, + body_gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[0], # function + operands, + fn_kwargs, + "hints_wrapper", + ) + + # hints_wrapper expects (body_node, args, kwargs) as positional args + # So we need to restructure p_args from (body_node, *lifted_args) + # to (body_node, lifted_args_tuple, {}) + body_node = p_args[0] + lifted_args = p_args[1:] + p_args = (body_node, tuple(lifted_args), {}) + + # add hints into p_kwargs + p_kwargs = {} + p_kwargs["hints"] = kwargs["hints"].as_python_constant() + + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + if len(kwargs) > 0: + unimplemented( + gb_type="out_dtype: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"out_dtype expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + p_args = tuple(arg.as_proxy() for arg in args) + op = p_args[0] + output_dtype = p_args[1] + fake_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:] + ) + # This is a simplified implementation of this operator just for tracing. + # Actual implementation may also first promote the arguments + example_value = op(*fake_sub_args).to(dtype=output_dtype) + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unpacked_sequence = args[1].unpack_var_sequence(tx) + # TODO (tmanlaibaatar) support pytree here + for arg in unpacked_sequence: + if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): + unimplemented( + gb_type="strict_mode: improper args", + context=f"args: {args}, kwargs: {kwargs}", + explanation="strict_mode higher order op expects flat inputs (list/tuple/dict)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + if kwargs: + unimplemented( + gb_type="strict_mode: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"strict_mode higher order op expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + ( + (ret_val, ret_spec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + unpacked_sequence, + {}, + "strict_mode", + source_target=self.value, + should_flatten_outputs=True, + ) + + strict_mode_nn_modules = dict(tx.output.nn_modules) + + strict_mode_name = tx.output.install_subgraph( + "strict_mode_body", + torch.fx.GraphModule(strict_mode_nn_modules, ret_graph), + ) + + strict_mode_node = make_attr(tx, strict_mode_name) + p_args = ( + strict_mode_node, + tuple(ret_lifted_freevars.keys()), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + ret_val.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.strict_mode, + p_args, + {}, + flat_example_value, + ret_spec, + ret_val, + ) + + +class CheckpointHigherOrderVariable(WrapHigherOrderVariable): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.allow_side_effects = ( + torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint + ) + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.wrap import TagActivationCheckpoint + from torch.utils.checkpoint import noop_context_fn + + context_fn = None + if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn: + ctx = kwargs.pop("context_fn") + if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable): + context_fn = ctx.fn + elif isinstance( + ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + context_fn = ctx.guard_as_python_constant() + else: + raise NotImplementedError( + f"checkpoint not implemented for {type(ctx)} context_fn" + ) + + checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs) + + # Here we use checkpoint_kwargs (and not gmod kwargs). gmod_kwargs are + # already flattened above and managed inside the fx graph. + ( + p_args, + _, + example_value, + _body_r, + checkpointed_gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[0], + args[1:], + gmod_kwargs, + "torch.utils.checkpoint.checkpoint", + ) + if context_fn is not None: + checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn + + _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs) + + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + checkpoint_kwargs, + example_value, + _body_r, + body_graph_output_vts, + ) + + +class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable): + def __init__(self, hop, source) -> None: + super().__init__(hop, source) + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + func_var = args[0] + + if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable): + func = func_var.fn + elif isinstance( + func_var, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + func = func_var.as_python_constant() + else: + raise RuntimeError( + f"DynamoBypassingWrapperHigherOrderVariable: Unsupported function {type(func_var)}" + ) + ( + p_args, + _, + example_value, + _body_r, + gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[1], + args[2:], + kwargs, + str(func), + ) + + # Alternatively, we could've stored only the function's fqn and + # reconstructed, but that requires the function to be a global. + gmod_meta_key = "_dynamo_bypassing_wrapper_fn" + gmod.meta[gmod_meta_key] = func + + return _call_function_with_auto_output_flattening( + tx, + self.value, + (gmod_meta_key,) + tuple(p_args), + {}, + example_value, + _body_r, + body_graph_output_vts, + ) + + +class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): + def proxy_submod(self, tx, arg): + assert isinstance(arg.source.base, DictGetItemSource) + submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value) + p_submod = make_attr(tx, submod_name) + set_example_value(p_submod.node, arg.value) + return p_submod + + def to_proxy(self, tx, arg): + if isinstance(arg, UnspecializedNNModuleVariable): + return self.proxy_submod(tx, arg) + elif isinstance(arg, (ListVariable, TupleVariable)): + return arg.python_type()( + self.to_proxy(tx, nested_arg) for nested_arg in arg.items + ) + else: + return arg.as_proxy() + + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + try: + p_args = tuple(self.to_proxy(tx, arg) for arg in args) + p_kwargs = {key: self.to_proxy(tx, arg) for key, arg in kwargs.items()} + except (NotImplementedError, Unsupported) as err: + unimplemented( + gb_type="failed to handle argument for FlexAttentionBackward HOP", + context=f"args: {args}, kwargs: {kwargs}", + explanation="Missing Dynamo support for FlexAttentionBackward HOP argument.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + from_exc=err, + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace + by unwrapping the higher order op and inlining through it. This op + is created by dynamo to survive through AotAutograd, then unwrapped + here in the call to dynamo from compiled autograd. + """ + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + kwargs = dict(kwargs) + fn = kwargs.pop("fn") + return fn.call_function(tx, args, kwargs) + + +class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable): + @staticmethod + def normalize_to_args(args, kwargs): + # input signature is (query, key, value, score_mod, block_mask, *other_buffers), + # block_mask is a tuple, and we don't want to flatten it. + # only flatten kwargs into lists + flat_kwargs = pytree.tree_flatten(kwargs)[0] + + # Combine the flattened lists + all_args = args + flat_kwargs + return all_args + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + query: "VariableTracker", + fn: "VariableTracker", + fn_name: str, + ): + from .._trace_wrapped_higher_order_op import TransformGetItemToIndex + + def create_scalar(): + return query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + { + "dtype": VariableTracker.build(tx, torch.int32), + }, + ) + + with discard_graph_changes(tx): + bhmn = [create_scalar() for _ in range(4)] + if fn_name == "score_mod": + scores_require_grad: bool = query.requires_grad + score = query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + {"requires_grad": VariableTracker.build(tx, scores_require_grad)}, + ) + new_args = [score, *bhmn] + else: + assert fn_name == "mask_fn", "Illegal function name: " + fn_name + new_args = [*bhmn] + + with TransformGetItemToIndex(): + ( + (_body_output, _body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn, + new_args, + {}, # expect only args no kwargs for now + description=fn_name, + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + body_name = tx.output.install_subgraph( + fn_name, + torch.fx.GraphModule(tx.output.nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + # It is possible that the score-mod function captures some free variables that are not + # passed in as arguments. In this case, we need to lift them, which is handled by speculate_subgraph. + # We then need to create proxies for this + the inputs. + + lifted_args = tuple(arg for arg in body_lifted_freevars) + + proxy_args = (body_node, lifted_args) + + return proxy_args + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + ) = self.normalize_to_args(args, kwargs) + + score_mod_node, score_mod_lifted_args = self.create_wrapped_node( + tx, query, score_mod, "score_mod" + ) + mask_fn = block_mask.items[-1] # type: ignore[attr-defined] + if mask_fn.is_python_constant() and mask_fn.as_python_constant() is None: + mask_fn = UserFunctionVariable( + torch.nn.attention.flex_attention.noop_mask, + source=mask_fn.source, + ) + mask_fn_node, mask_fn_lifted_args = self.create_wrapped_node( + tx, query, mask_fn, "mask_fn" + ) + + proxied_args = [ + query, + key, + value, + TupleVariable(block_mask.items[:-1], source=block_mask.source), + scale, + kernel_options, + ] + + # Store the invocation as a call + # Norm_kwargs contains the score_function and we dont want to proxy this because + # Proxying user defined functions is not supported. + inp_args, _ = proxy_args_kwargs(proxied_args, {}) + + # Compose the ordered HOO args: + # - inp_args: [query, key, value, block_mask, scale, kernel_options] + # - subgraph node: [score_mod, mask_fn_node] + # - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers] + _, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args + block_mask = tuple(inp_arg_block_mask + (mask_fn_node,)) + with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value): + proxy = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=inp_args[:3] + + ( + score_mod_node, + block_mask, + inp_arg_scale, + inp_arg_kernel_options, + score_mod_lifted_args, + mask_fn_lifted_args, + ), + kwargs={}, + ), + example_value=None, + ) + return proxy + + +class AutogradFunctionApplyVariable(VariableTracker): + def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None: + super().__init__(**kwargs) + self.fwd_graph = fwd_graph + self.bwd_graph = bwd_graph + self.parent_source = parent_source + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + AutogradFunctionContextVariable, + UserDefinedClassVariable, + UserFunctionVariable, + UserMethodVariable, + ) + from .builder import wrap_fx_proxy + + """ + Consider the following: + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + @staticmethod + def backward(ctx, grad): + x, = ctx.saved_tensors + return grad * x.cos() + We want the resulting graphs to look like: + def fwd(ctx, x): + # (output, saved tensors / attrs) + return (x.sin(), [x]) + # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs) + def bwd(ctx, grad, x): + return grad * x.cos() + To accomplish this, we're going to: + 1. Construct a ctx object + 2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True) + 3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting + the ctx and grad inputs. + 4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph) + Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is + just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward + doesn't capture any arguments. + All these steps work if MySin.backward doesn't capture any values. This is a + limitation in general that we should check for. + """ + + prev_side_effects = tx.output.side_effects.clone() + fwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=tx.output.current_tracer, + source_target="autograd.Function", + ) + + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + with discard_graph_changes(tx): + # A little hacky, but we need a dummy ctx proxy for speculate_subgraph. + # We should clean this up at some point. + proxy = tx.output.create_proxy( + "call_function", torch.autograd.function.FunctionCtx, (), {} + ) + set_example_value(proxy.node, ctx.value) + ctx.proxy = proxy + + if isinstance(self.fwd_graph, types.FunctionType): + fwd_fn = UserFunctionVariable(self.fwd_graph) + fwd_args = [ctx, *args] + elif isinstance(self.fwd_graph, types.MethodType): + fwd_fn = UserMethodVariable( + self.fwd_graph.__func__, + UserDefinedClassVariable(self.fwd_graph.__class__), + ) + fwd_args = [fwd_fn.obj, ctx, *args] + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method forward", + context=str(self.fwd_graph), + explanation="Expected forward function to be a function or method.", + hints=[], + ) + + # Speculate subgraph on the fwd + (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph( + tx, + fwd_fn, + fwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="semi_automatic", + restore_side_effects=False, + tracer=fwd_tracer, + ) + + if ctx in tx.output.side_effects.store_attr_mutations: + if ( + "_materialize_non_diff_grads" + in tx.output.side_effects.store_attr_mutations[ctx] + ): + unimplemented( + gb_type="autograd.Function.apply: _materialize_non_diff_grads mutation", + context="", + explanation="Mutations to autograd.Function.ctx._materialize_non_diff_grads are not supported.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + + # Speculate subgraph on the backward. We make the + # bwd tracer a child of the fwd tracer, because backward may rely on + # tensors/attrs created in the fwd tracer. + + if isinstance(fwd_out, variables.BaseListVariable): + bwd_args = [ctx, *fwd_out.items] + else: + bwd_args = [ctx, fwd_out] + + bwd_src = AttrSource(self.parent_source, member="backward") + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + self.bwd_graph.__func__, + UserDefinedClassVariable(self.bwd_graph.__class__), + source=bwd_src, + ) + bwd_args = [bwd_fn.obj, *bwd_args] + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method backward", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or method.", + hints=[], + ) + + def is_strict_for(v: VariableTracker): + if v.is_tensor(): + # we can be more lax for stuff from forward + return v.proxy.tracer is not fwd_tracer + return True + + with ( + tx.output.subtracer(fwd_fn, fwd_tracer), + tx.strict_translation_mode(is_strict_for), + ): + try: + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + except torch._dynamo.exc.Unsupported as e: + if isinstance( + e, torch._dynamo.exc.UnknownPropertiesDuringBackwardTrace + ): + from unittest import mock + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + from .._trace_wrapped_higher_order_op import ( + autograd_function_backward_rewritten, + ) + + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable( + autograd_function_backward_rewritten(self.bwd_graph) + ) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + autograd_function_backward_rewritten( + self.bwd_graph.__func__ + ), + UserDefinedClassVariable(self.bwd_graph.__class__), + ) + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method backward (2)", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or method.", + hints=[], + ) + + with mock.patch( + "torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops", + [], + ): + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + else: + raise e + + # TODO: assert that bwd_graph didn't capture values that were + # not created inside fwd_graph. + + # TODO(oulgen): Ideally, we would not do a linear search for output + # node but as things currently are there could be nodes after the + # output node + # This is bug prone as if there's code after the output node, then + # graph.output will append the output at the very end + # This might be a behavior difference + + # If users call ctx.mark_non_differentiable, we should capture these output tensors who + # are marked as non-differentiable and pass them to ApplyTemplate + # at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction. + non_differentiable_idx = [] + if ctx.non_differentiable is not None: + non_differentiable_set = set(ctx.non_differentiable) + assert isinstance(fwd_out, variables.BaseListVariable) + for i, x in enumerate(fwd_out.items): + if x.is_tensor() and x.as_proxy() in non_differentiable_set: + non_differentiable_idx.append(i) + + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) + for node in fwd_graph.find_nodes(op="output"): + fwd_graph.erase_node(node) + break + + # Because we lift the bwd_freevars as inputs of the bwd_graph, + # we have to manually add the bwd_freevars as output of fwd_graph. + # However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph, + # we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output. + fwd_proxy_of_bwd_freevars = [] + for k in bwd_freevars: + if k in fwd_freevars: + fwd_proxy_of_bwd_freevars.append(fwd_freevars[k]) + else: + fwd_proxy_of_bwd_freevars.append(k) + + def unwrap_proxy(x): + if isinstance(x, torch.fx.Proxy): + return x.node + else: + assert variables.ConstantVariable.is_literal(x), ( + f"Only constant is allowed. Got {x}" + ) + return x + + new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars) + new_fwd_graph_outputs = pytree.tree_map(unwrap_proxy, new_fwd_graph_outputs) + fwd_graph.output(new_fwd_graph_outputs) + fwd_graph.lint() + + # Store fwd_body + fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + fwd_name = tx.output.install_subgraph( + "fwd_body", + torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph), + ) + + fwd_node = make_attr(tx, fwd_name) + + # The type of original args can be arbitrary, but we only support basic type in FX graph. + # So the speculated subgraph input includes original tensor args and the lifted freevars. + # We need to filter out the original tensor args and concat them with the lifted freevars + # to generate the proxy args for the FX call_function node. + filtered_args = [] + # A boolean list to mark if the type of corresponding argument is tensor. + # This is used to determine if a FX node's argument should be an argument of + # ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward + # at torch._functorch.autograd_function.AutogradFunctionApply. + args_tensor_mask = [False] * len(args) + for i, arg in enumerate(args): + if arg.is_tensor() or isinstance(arg, SymNodeVariable): + filtered_args.append(arg) + args_tensor_mask[i] = True + + # Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args. + new_bwd_graph_outputs = None + for node in bwd_graph.find_nodes(op="output"): + bwd_graph.erase_node(node) + break + + # The same as the above fwd proxies, we need to use the bwd proxies in the bwd_graph + # if some of the output is from fwd_freevars. + bwd_out_proxy = bwd_out.as_proxy() + bwd_proxy_of_fwd_freevars = [] + if isinstance(bwd_out_proxy, (tuple, list)): + for k in bwd_out_proxy: + if k in bwd_freevars: + bwd_proxy_of_fwd_freevars.append(bwd_freevars[k]) + else: + bwd_proxy_of_fwd_freevars.append(k) + else: + if bwd_out_proxy in bwd_freevars: + bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy] + else: + bwd_proxy_of_fwd_freevars = bwd_out_proxy + + # Remove bwd output for non-Tensor args. + output_proxy = bwd_proxy_of_fwd_freevars + if isinstance(output_proxy, (tuple, list)): + new_bwd_graph_outputs = () + for x, mask in zip(output_proxy, args_tensor_mask): + if mask: + new_bwd_graph_outputs = new_bwd_graph_outputs + (x,) + else: + assert x is None, f"Grad of non-Tensor arg {x} is not None." + else: + new_bwd_graph_outputs = output_proxy + + # Update the bwd graph output. + new_bwd_graph_outputs = pytree.tree_map( + lambda x: None if x is None else x.node, new_bwd_graph_outputs + ) + bwd_graph.output(new_bwd_graph_outputs) + bwd_graph.lint() + + # Store bwd_body + bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + bwd_name = tx.output.install_subgraph( + "bwd_body", + torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph), + ) + + bwd_node = make_attr(tx, bwd_name) + + tx.output.side_effects = prev_side_effects + + p_args = ( + fwd_node, + bwd_node, + *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())), + ) + kwargs = { + "args_tensor_mask": args_tensor_mask, + "non_differentiable_idx": non_differentiable_idx, + } + + # Store the invocation as a call + from torch._functorch.autograd_function import autograd_function_apply + + # We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does. + # The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes + # (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing. + # Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it. + with enable_python_dispatcher(), tx.output.fake_mode: + fake_args = ( + tx.output.nn_modules[fwd_node.node.name], + tx.output.nn_modules[bwd_node.node.name], + *( + [ + _get_fake_value(arg) + for arg in filtered_args + list(fwd_freevars.keys()) + ] + ), + ) + example_value = autograd_function_apply(*fake_args, **kwargs) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + autograd_function_apply, + args=p_args, + kwargs=kwargs, + ), + example_value=example_value, + ) + + +def _get_fake_value(x): + if isinstance(x, variables.VariableTracker): + return x.as_proxy().node.meta["example_value"] + elif isinstance(x, torch.fx.Proxy): + return x.node.meta["example_value"] + else: + return x + + +def maybe_positional_arg_names(func): + result = [] + if not hasattr(func, "get_function"): + return None + try: + fn = func.get_function() + except (Unsupported, NotImplementedError): + return None + try: + sig = inspect.signature(fn) + except ValueError: + return None + for name, param in sig.parameters.items(): + if param.kind is inspect.Parameter.VAR_POSITIONAL: + return None + if ( + param.kind is inspect.Parameter.POSITIONAL_ONLY + or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + if name == "self": + # FX graphs can't have a placeholder named self + result.append("self_") + else: + result.append(name) + return result + + +class BaseHOPVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + def python_type(self): + return type(self.value) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + ( + p_args, + p_kwargs, + example_value, + body_r, + _, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, args[0], args[1:], {}, self.value._name, subgraph_name="subgraph" + ) + assert len(p_kwargs) == 0 + + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): + supports_input_mutation = True + supports_aliasing = False + allow_side_effects = True + # invoke_subgraph is NOT desugared in AOTAutograd, so the HOP input/output + # shouldn't alias. For checkpoint HOP, we inline it so we don't need + # alias analysis as functionalization would just work on the flat graph. + filter_aliased_intermediates = True + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name + ): + # Check if the subgraph from speculate_subgraph (body_gmod) and the fake + # inputs have already been seen before. If yes, the subgraph is already + # installed in the output graph and we can just access the subgraph + # using the saved attr name. + + if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): + unimplemented( + gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", + context=str(fn_vt), + explanation="invoke_subgraph does not support non user function variable", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + invoke_subgraph_cache = ( + tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + ) + + if isinstance(fn_vt, UserFunctionVariable): + fn_id = id(fn_vt.get_function()) + fn_name = fn_vt.get_function().__name__ + else: + assert isinstance(fn_vt, UnspecializedNNModuleVariable) + fn_id = id(fn_vt.value.forward.__func__) + fn_name = fn_vt.value.forward.__name__ + previously_installed_submodules = [] + if invoke_subgraph_cache: + previously_installed_submodules = ( + invoke_subgraph_cache.get_dynamo_installed_submodules(fn_id) + ) + current_mod = body_gmod + # NB - reverse is more likely to cause a hit sooner because first + # graph can have requires_grad=False for a few inputs + for submodule_name in reversed(previously_installed_submodules): + assert submodule_name in tx.output.nn_modules + previous_mod = tx.output.nn_modules[submodule_name] + if are_same_graph_modules( + fn_name, previous_mod, current_mod, tx.fake_mode + ): + return submodule_name + + body_name = super().install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" + ) + hc_log.debug( + "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", + fn_name, + body_name, + fn_name, + len(previously_installed_submodules) + 1, + ) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name) + + return body_name + + @raise_hard_error_if_graph_break( + reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + example_value, + body_r, + _, + body_name, + body_graph_output_vts, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph") + + if len(p_kwargs) > 0: + unimplemented( + gb_type="invoke_subgraph: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + p_args = ( + p_args[0], + body_name, + *p_args[1:], + ) + return _call_function_with_auto_output_flattening( + tx, + torch._higher_order_ops.invoke_subgraph, + tuple(p_args), + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + # Subclasses aren't supported by speculate_subgraph yet + # So this HOP is only usable with plain tensors + _enabled = False + + @classmethod + @contextlib.contextmanager + def enable(cls): + """Context manager to temporarily enable local map wrapping. + Will be removed when speculate_subgraph supports subclass inputs: + https://github.com/pytorch/pytorch/issues/161456. + + Usage: + with LocalMapWrappedHigherOrderVariable.enable_wrapping(): + # Code where should_wrap_in_hop will return True + pass + """ + old_value = cls._enabled + cls._enabled = True + try: + yield + finally: + cls._enabled = old_value + + @classmethod + def should_wrap_in_hop(cls, value): + if not torch.distributed.is_available(): + return False + + from torch.distributed.tensor.experimental._func_map import _local_map_wrapped + + # check is important to avoid subclass dispatch + if type(value) is not type(_local_map_wrapped): + return False + + return value is _local_map_wrapped and cls._enabled + + @staticmethod + def build(**options): + return TorchHigherOrderOperatorVariable.make( + torch._higher_order_ops.local_map_hop, + **options, + ) + + def python_type(self): + return type(self.value) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + """ + Goal of this function is to rewrite local_map usage as a HOP: + local_map(func, ...) -> local_map_hop(gm, ...) + """ + + ( + user_func, + out_placements, + in_placements, + in_grad_placements, + device_mesh, + redistribute_inputs, + *user_args, + ) = args + + # None placements are used to pass non-Tensors into the local_map function. + # Containers passed this way can not hold tensors. Thus, Dynamo would have inlined + # into them, and we handle None placements by assuming they will be desugared away. + # This will need to be adjusted for dynamic shapes support. + def check_none_last(placements): + seen_none = 0 + for p in placements: + if p is None: + seen_none += 1 + else: + assert seen_none == 0, ( + "Tracing local_map is only currently supported with None placements last." + ) + return seen_none + + inputs_none_placements = check_none_last(in_placements.value) + output_none_placements = check_none_last(out_placements.value) + + local_map_kwargs = { + "out_placements": out_placements.value, + "in_placements": in_placements.value, + "redistribute_inputs": redistribute_inputs.value, + "in_grad_placements": in_grad_placements.value, + "device_mesh": device_mesh.value, + } + assert local_map_kwargs["device_mesh"] is not None, ( + "Not yet implemented, please manually provide a device_mesh to local_map." + ) + mesh = local_map_kwargs["device_mesh"] + + # For Autoparallel, the initial trace is done with global shapes, then we decide model weights sharding, + # and reuse the graph. Since the sharding decision is after the initial trace, we can't trace with local shapes. + # For local_map however, since we specify all placements, we can trace with local shapes. + + # Step 1: Validate the annotated function matches the input_placements (i.e. that it can run in eager) + template = ( + "Expecting {expected} {inputs_or_outputs} to local_map function based on placements" + ", but found {actual}. Please ensure the count matches for eager. " + ) + assert len(in_placements.value) == len(user_args), template.format( + expected=len(in_placements.value), + inputs_or_outputs="inputs", + actual=len(user_args), + ) + + from torch._higher_order_ops.local_map import ( + redistribute_fw_inputs, + redistribute_fw_outputs, + ) + + # Step 2: Convert inputs to local shapes + priors = {} + for placements, vt in zip(in_placements.value, user_args): + if isinstance(vt, variables.lazy.LazyVariableTracker): + vt = variables.lazy.LazyVariableTracker.realize_all(vt) + + if not vt.is_tensor(): + assert placements is None + continue + + global_tensor = vt.as_proxy().node.meta["example_value"] + # NOTE: We don't support local_map region relying on exact grad_fn information + # This is okay since accessing grad_fn is a graph break. + local_tensor = redistribute_fw_inputs( + (global_tensor,), + (placements,), + mesh, + ) + local_tensor = local_tensor[0] + + priors[vt] = global_tensor + vt.as_proxy().node.meta["example_value"] = local_tensor + vt.synchronize_attributes(tx) + + # Step 3: Trace local_map subgraph with local tensors + ( + p_args, + p_kwargs, + example_value, + body_r, + body_gmod, + body_name, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph" + ) + + # Step 4: Validate traced graph signature still matches placement information + expected_num_inputs = len(in_placements.value) - inputs_none_placements + actual_num_inputs = len(body_gmod.graph.find_nodes(op="placeholder")) + expected_num_outputs = len(out_placements.value) - output_none_placements + assert len(body_gmod.graph.find_nodes(op="output")) == 1 + actual_num_outputs = len(body_gmod.graph.find_nodes(op="output")[0].args[0]) + + template = ( + "Expecting {expected} {inputs_or_outputs} to local_map function based on placements" + ", but found {actual}. If the count matches for eager, " + "Dynamo may have flattened {inputs_or_outputs} to the function or found additional " + "tensors used via closures. " + "Please adjust the input placements to match what the traced graph sees: \n{gm_str}." + ) + + def make_error_msg(*args): + expected_num, actual_num, inputs_or_outputs = args + gm_str = body_gmod.print_readable(print_output=False) + return template.format( + expected=expected_num, + inputs_or_outputs=inputs_or_outputs, + actual=actual_num, + gm_str=gm_str, + ) + + if expected_num_inputs != actual_num_inputs: + raise AssertionError( + make_error_msg(expected_num_inputs, actual_num_inputs, "inputs") + ) + if expected_num_outputs != actual_num_outputs: + raise AssertionError( + make_error_msg(expected_num_outputs, actual_num_outputs, "outputs") + ) + + if inputs_none_placements > 0: + expected_input_nodes = [ + arg.as_proxy().node for arg in user_args[:-inputs_none_placements] + ] + else: + expected_input_nodes = [arg.as_proxy().node for arg in user_args] + actual_input_nodes = [proxy.node for proxy in p_args] + assert actual_input_nodes[0].op == "get_attr" + assert "subgraph" in actual_input_nodes[0].target + assert len(expected_input_nodes) == len(actual_input_nodes) - 1 + for expected_order, actual_order in zip( + expected_input_nodes, actual_input_nodes[1:] + ): + assert expected_order == actual_order, ( + "Dynamo changed the order of inputs to the local_map function, please adjust " + f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}" + ) + assert len(p_kwargs) == 0 + + # Step 5: Install local_map subgraph + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + out = _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + # Step 6: Restore inputs and outputs to global shapes + for vt, global_tensor in priors.items(): + vt.as_proxy().node.meta["example_value"] = global_tensor + vt.synchronize_attributes(tx) + + outs = out.items if isinstance(out, TupleVariable) else [out] + assert len(outs) == len(out_placements.value) + for placements, vt in zip(out_placements.value, outs): + if not vt.is_tensor(): + assert placements is None + continue + + local_tensor = vt.as_proxy().node.meta["example_value"] + + # NOTE: We don't support code after the local_map region relying on exact grad_fn information + # This is okay since accessing grad_fn is a graph break. + global_tensor = redistribute_fw_outputs( + (local_tensor,), + (placements,), + mesh, + num_activations=0, # this is not the joint + ) + global_tensor = global_tensor[0] + + vt.as_proxy().node.meta["example_value"] = global_tensor + vt.synchronize_attributes(tx) + + # TODO: Figure out how to handle output order diverging from eager + + # Treat as const, so we don't have to deal with Placement types in fx IR + # Guarded with EQUALS_MATCH on local_map call's arguments + body_gmod.meta["local_map_kwargs"] = { + "out_placements": out_placements.value[:expected_num_outputs], + "in_placements": in_placements.value[:expected_num_inputs], + "redistribute_inputs": redistribute_inputs.value, + "in_grad_placements": in_grad_placements.value, + "device_mesh": device_mesh.value, + } + + return out + + +# Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() +_hop_name_to_variable_class = { + "cond": CondHigherOrderVariable, + "while_loop": WhileLoopHigherOrderVariable, + "while_loop_stack_output": WhileLoopStackOutputHigherOrderVariable, + "map_impl": MapHigherOrderVariable, + "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable, + "out_dtype": OutDtypeHigherOrderVariable, + "wrap": WrapHigherOrderVariable, + "hints_wrapper": HintsWrapperHigherOrderVariable, + "flex_attention": FlexAttentionHigherOrderVariable, + "flex_attention_backward": FlexAttentionBackwardHighOrderVariable, + "wrap_activation_checkpoint": CheckpointHigherOrderVariable, + "tag_activation_checkpoint": CheckpointHigherOrderVariable, + "_export_tracepoint": ExportTracepointHigherOrderVariable, + "trace_wrapped": TraceWrappedHigherOrderOperatorVariable, + "strict_mode": StrictModeHigherOrderVariable, + "run_with_rng_state": RunWithRNGStateHigherOrderVariable, + "associative_scan": AssociativeScanHigherOrderVariable, + "scan": ScanHigherOrderVariable, + "call_torchbind": CallTorchbindHigherOrderVariable, + "print": PrintHigherOrderVariable, + "wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable, + "wrap_with_autocast": WrapWithAutocastHigherOrderVariable, + "dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable, + "auto_functionalized": AutoFunctionalizeHigherOrderVariable, + "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable, + "invoke_subgraph": InvokeSubgraphHigherOrderVariable, + "custom_function_call": CustomFunctionHigherOrderOperatorVariable, + "local_map_hop": LocalMapWrappedHigherOrderVariable, +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3c0247add1b44329a2555ce49341fe75602ba2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py @@ -0,0 +1,620 @@ +""" +This module provides iterator-related variable tracking functionality for Dynamo. +It implements variable classes for handling Python iterators and itertools functions +during symbolic execution and tracing. + +The module includes: +- Base iterator variable classes for tracking iterator state +- Implementations of built-in iterators (zip, map, filter) +- Support for itertools functions (product, accumulate, combinations, etc.) +- Mutation tracking and reconstruction capabilities for iterator operations + +These classes integrate with Dynamo's variable tracking system to enable proper +handling of iterator operations during code transformation and optimization. +""" + +import itertools +import sys +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import ( + create_build_tuple, + create_call_function, + create_call_function_ex, + create_instruction, +) +from ..exc import ( + handle_observed_exception, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, + UserError, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +MAX_ITERATOR_LIMIT = 100 * 1024 # 100k + + +class ItertoolsVariable(VariableTracker): + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + def __repr__(self) -> str: + return f"ItertoolsVariable({self.value})" + + def as_python_constant(self) -> Any: + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence["VariableTracker"], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # See also: module `torch._dynamo.polyfills.itertools` + + if self.value is itertools.product: + if any(kw != "repeat" for kw in kwargs): + unimplemented( + gb_type="Unsupported kwargs for itertools.product", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'repeat', but got " + f"{','.join(set(kwargs.keys()) - {'repeat'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + if "repeat" in kwargs: + r = kwargs["repeat"].as_python_constant() + else: + r = 1 + seqs = [arg.force_unpack_var_sequence(tx) for arg in args] + items = [ + variables.TupleVariable(list(item)) + for item in itertools.product(*seqs, repeat=r) + ] + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif ( + self.value is itertools.combinations + and not kwargs + and len(args) == 2 + and args[0].has_unpack_var_sequence(tx) + and args[1].is_python_constant() + ): + iterable = args[0].unpack_var_sequence(tx) + r = args[1].as_python_constant() + + items = [] + for item in itertools.combinations(iterable, r): + items.append(variables.TupleVariable(list(item))) + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif self.value is itertools.groupby: + if any(kw != "key" for kw in kwargs): + unimplemented( + gb_type="Unsupported kwargs for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'key', but got " + f"{','.join(set(kwargs.keys()) - {'key'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + def retrieve_const_key(key: VariableTracker) -> Any: + if isinstance(key, variables.SymNodeVariable): + return key.evaluate_expr() + elif key.is_python_constant(): + return key.as_python_constant() + else: + unimplemented( + gb_type="Unsupported key type for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with key type: {str(type(key))}. " + "We only support grouping keys that are constants (int, float, str, etc.)", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + else: + unimplemented( + gb_type="Unsupported arguments for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with args: {args} and kwargs: {kwargs}. " + "itertools.groupby expects an iterable to group and an " + "optional key function to determine groupings.", + hints=[ + "Make sure the arguments to itertools.groupby are correct.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if "key" in kwargs: + + def keyfunc(x: VariableTracker) -> Any: + return retrieve_const_key( + kwargs.get("key").call_function(tx, [x], {}) # type: ignore[union-attr] + ) + + else: + + def keyfunc(x: VariableTracker) -> Any: + return retrieve_const_key(x) + + result = [] + try: + # pyrefly: ignore [unbound-name] + for k, v in itertools.groupby(seq, key=keyfunc): + result.append( + variables.TupleVariable( + [ + ( + variables.ConstantVariable.create(k) + if variables.ConstantVariable.is_literal(k) + else k + ), + variables.ListIteratorVariable( + list(v), mutation_type=ValueMutationNew() + ), + ], + mutation_type=ValueMutationNew(), + ) + ) + except Exception as e: + unimplemented( + gb_type="Unexpected failure during itertools.groupby() iteration", + context=f"call_function {self} {args} {kwargs}", + explanation="Unexpected failure in invoking function during groupby", + hints=[*graph_break_hints.SUPPORTABLE], + from_exc=e, + ) + return variables.ListIteratorVariable( + result, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif self.value is itertools.repeat: + if len(args) < 2: + return variables.RepeatIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.repeat), args, kwargs + ) + elif self.value is itertools.count: + return variables.CountIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + elif ( + self.value is itertools.permutations + and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) + and not kwargs + ): + if len(args) == 2: + r = args[1].as_python_constant() + else: + r = None + items = [ + variables.TupleVariable(list(item)) + for item in itertools.permutations( + args[0].force_unpack_var_sequence(tx), r + ) + ] + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + else: + return super().call_function(tx, args, kwargs) + + +class IteratorVariable(VariableTracker): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + unimplemented( + gb_type="Unimplemented next() call", + context=f"next({self})", + explanation="This abstract method must be implemented", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[Any], Any] + ) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + if name == "__iter__" or name == "__next__": + return variables.ConstantVariable.create(True) + return super().call_obj_hasattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__iter__": + return self + elif name == "__next__": + return self.next_variable(tx) + return super().call_method(tx, name, args, kwargs) + + +class ObjectIteratorVariable(IteratorVariable): + """ + VariableTracker for iter(obj) that implements the iterator protocol (i.e., + has a `__next__` method). + + We use this class to track the state of the iterator and handle the case + when the iterator is exhausted: + + Example usage: + > b = iter(obj) + > list(b) # exhaust the iterator + > list(b) # empty list + """ + + def __init__(self, obj: VariableTracker, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.obj = obj + self.generator_exhausted = False + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + if self.generator_exhausted: + raise_observed_exception(StopIteration, tx) + + try: + return self.obj.next_variable(tx) + except ObservedUserStopIteration: + # Do not rely on the object to always return StopIteration once it + # is exhausted. + self.generator_exhausted = True + raise + + +class RepeatIteratorVariable(IteratorVariable): + def __init__(self, item: VariableTracker, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.item = item + + # Repeat needs no mutation, clone self + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + return self.item + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + + +class CountIteratorVariable(IteratorVariable): + def __init__( + self, + item: Union[int, VariableTracker] = 0, + step: Union[int, VariableTracker] = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not isinstance(item, VariableTracker): + item = ConstantVariable.create(item) + if not isinstance(step, VariableTracker): + step = ConstantVariable.create(step) + self.item = item + self.step = step + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + old_item = self.item + tx.output.side_effects.mutation(self) + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: list[VariableTracker], + strict: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self) -> type[zip]: # type: ignore[type-arg] + return zip + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + + if len(self.iterables) == 0: + raise_observed_exception(StopIteration, tx) + + old_index = self.index + args = [] + + def get_item( + it: Union[list[VariableTracker], VariableTracker], + ) -> VariableTracker: + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx) + return it[old_index] + else: + return it.next_variable(tx) + + idx: int | None = None + try: + for idx, it in enumerate(self.iterables): # noqa:B007 + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, # type: ignore[arg-type] + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen: "PyCodegen") -> None: + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output(create_build_tuple(len(remaining_items))) + else: + codegen(it) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output(create_build_tuple(len(self.iterables))) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + *create_call_function_ex(True, False), + ] + ) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self) -> type: + return map + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return False + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) # type: ignore[attr-defined] + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.append_output(create_build_tuple(len(self.iterables) + 1)) + if self.strict: + assert sys.version_info >= (3, 14), ( + "Unexpected bug: map(strict=True) requires Python 3.14+" + ) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + *create_call_function_ex(True, False), + ] + ) + else: + codegen.extend_output(create_call_function_ex(False, False)) + + +class FilterVariable(IteratorVariable): + """ + Represents filter(fn, iterable) + """ + + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + fn: VariableTracker, + iterable: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.fn = fn + self.iterable = iterable + self.index = 0 + + def python_type(self) -> type: + return filter + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence( + tx + ) + + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + it = None + if isinstance(self.iterable, list): + it = self.iterable[self.index :] + else: + it = self.iterable.unpack_var_sequence(tx) + filtered = self.fn.call_function(tx, it, {}) + return [variables.TupleVariable([filtered])] + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + def _next() -> VariableTracker: + old_index = self.index + if isinstance(self.iterable, list): + if old_index >= len(self.iterable): + raise_observed_exception(StopIteration, tx) + return self.iterable[old_index] + else: + return self.iterable.next_variable(tx) + + # A do-while loop to find elements that make fn return true + while True: + item = _next() + self.index += 1 + if self.fn.is_constant_none(): + res = item + else: + res = self.fn.call_function(tx, [item], {}) + pred_res = variables.UserFunctionVariable( + polyfills.predicate # type: ignore[arg-type] + ).call_function(tx, [res], {}) + if pred_res.as_python_constant(): + return item + + def reconstruct_items(self, codegen: "PyCodegen") -> None: + if isinstance(self.iterable, list): + remaining_items = self.iterable[self.index :] + codegen.foreach(remaining_items) + codegen.append_output(create_build_tuple(len(remaining_items))) + else: + codegen(self.iterable) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output(create_call_function(2, False)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..74609e0884cb284f4d9e286696e2cdde4e7d8e1f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import collections +import functools +import inspect +from typing import Any, TYPE_CHECKING + +from ..utils import is_function_or_wrapper +from .base import VariableTracker, VariableTrackerMeta + + +if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import Self + + from .tensor import SymNodeVariable + + +class LazyCache: + """Container to cache the real VariableTracker""" + + def __init__(self, value: Any, source: Any) -> None: + if not isinstance(value, LazySymNodeFormatString): + assert source + self.value = value + self.source = source + self.name_hint: str | None = None + self.vt: VariableTracker | None = None + + def realize(self) -> None: + assert self.vt is None + from ..symbolic_convert import InstructionTranslator + from . import builder + + tx = InstructionTranslator.current_tx() + + if isinstance(self.value, LazySymNodeFormatString): + self.vt = builder.SourcelessBuilder.create(tx, self.value) + else: + self.vt = builder.VariableBuilder(tx, self.source)(self.value) + + if self.name_hint is not None: + # pyrefly: ignore [missing-attribute] + self.vt.set_name_hint(self.name_hint) + + del self.value + del self.source + del self.name_hint + + +class LazyVariableTracker(VariableTracker, metaclass=VariableTrackerMeta): + """ + A structure that defers the creation of the actual VariableTracker + for a given underlying value until it is accessed. + + The `realize` function invokes VariableTracker.build() to produce the real object. + Once a LazyVariableTracker has been realized, internal bookkeeping will + prevent double realization. + + This object should be utilized for processing containers, or objects that + reference other objects where we may not want to take on creating all the + VariableTrackers right away. + """ + + # Flag to prevent implicit realization in isinstance checks (inherited by subclasses) + _no_implicit_realize = True + _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields} + + @staticmethod + def create(value: Any, source: Any, **options: Any) -> LazyVariableTracker: + return LazyVariableTracker(LazyCache(value, source), source=source, **options) + + def __init__(self, _cache: LazyCache, **kwargs: Any) -> None: + assert isinstance(_cache, LazyCache) + super().__init__(**kwargs) + self._cache = _cache + + def realize(self) -> VariableTracker: + """Force construction of the real VariableTracker""" + if self._cache.vt is None: + self._cache.realize() + assert self._cache.vt is not None + return self._cache.vt + + def lazy_isinstance(self, cls: type) -> bool: + """Check isinstance after realizing, used by ImplicitRealizingVariableTrackerMeta""" + return type.__instancecheck__(cls, self.realize()) + + def unwrap(self) -> VariableTracker | Self: + """Return the real VariableTracker if it already exists""" + if self.is_realized(): + assert self._cache.vt is not None + return self._cache.vt + return self + + def is_realized(self) -> bool: + return self._cache.vt is not None + + def clone(self, **kwargs: Any) -> VariableTracker: + assert kwargs.get("_cache", self._cache) is self._cache + if kwargs.get("source", self.source) is not self.source: + self.realize() + return VariableTracker.clone(self.unwrap(), **kwargs) + + def peek_type(self) -> type[Any]: + assert not self.is_realized() + return type(self._cache.value) + + def peek_value(self) -> Any: + assert not self.is_realized() + return self._cache.value + + def set_name_hint(self, name: str) -> None: + if self.is_realized(): + self._cache.vt.set_name_hint(name) # type: ignore[union-attr] + else: + self._cache.name_hint = name + + def __str__(self) -> str: + variable_info = "LazyVariableTracker(" + if self.is_realized(): + variable_info += f"realized: {repr(self.unwrap())})" + else: + variable_info += f"unrealized: {self.peek_type()})" + + return variable_info + + def __getattr__(self, item: str) -> Any: + return getattr(self.realize(), item) + + # most methods are auto-generated below, these are the ones we want to exclude + visit = VariableTracker.visit # type: ignore[assignment] + __repr__ = __str__ + + @classmethod + def realize_all( + cls, + value: Any, + cache: dict[int, tuple[Any, Any]] | None = None, + ) -> Any: + """ + Walk an object and realize all LazyVariableTrackers inside it. + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return cache[idx][0] + + value_cls = type(value) + if issubclass(value_cls, LazyVariableTracker): + result = cls.realize_all(value.realize(), cache) + elif issubclass(value_cls, VariableTracker): + # update value in-place + result = value + value_dict = value.__dict__ + nonvars = value._nonvar_fields + for key in value_dict: + if key not in nonvars: + value_dict[key] = cls.realize_all(value_dict[key], cache) + elif value_cls is list: + result = [cls.realize_all(v, cache) for v in value] + elif value_cls is tuple: + result = tuple(cls.realize_all(v, cache) for v in value) + elif value_cls in (dict, collections.OrderedDict): + result = {k: cls.realize_all(v, cache) for k, v in list(value.items())} + else: + result = value + + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = (result, value) + return result + + def is_hashable(self) -> bool: + # Checks that the underlying value is hashable without realizing the VT. + # This is used by ConstDictVariable tracker to find if the key LazyVT + # can be hashed. + def _helper(value: Any) -> bool: + # TODO: Add support for more types + return ( + inspect.isbuiltin(value) + or issubclass(type(value), type) + or is_function_or_wrapper(value) + ) + + assert not self.is_realized() + value = self._cache.value + if isinstance(value, tuple): + return all(_helper(v) for v in value) + return _helper(value) + + def original_value(self) -> Any: + # Returns the value without realizing the VT. + assert not self.is_realized() + return self._cache.value + + def original_source(self) -> Any: + # Returns the source without realizing the VT. + assert not self.is_realized() + return self._cache.source + + +class LazySymNodeFormatString: + def __init__( + self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker + ) -> None: + from .constant import ConstantVariable + + self.sym_node_var = sym_node_variable + self.fmt_var = ConstantVariable.create( + "{:" + fmt_spec_var.as_python_constant() + "}" + ) + + def __repr__(self) -> str: + return str.format( + self.fmt_var.as_python_constant(), + str(self.sym_node_var.evaluate_expr()), + ) + + +def _create_realize_and_forward( + name: str, +) -> Callable[[LazyVariableTracker, Any, Any], Any]: + @functools.wraps(getattr(VariableTracker, name)) + def realize_and_forward( + self: LazyVariableTracker, *args: Any, **kwargs: Any + ) -> Any: + return getattr(self.realize(), name)(*args, **kwargs) + + return realize_and_forward + + +def _populate() -> None: + for name, value in VariableTracker.__dict__.items(): + if name not in LazyVariableTracker.__dict__: + if callable(value): + setattr(LazyVariableTracker, name, _create_realize_and_forward(name)) + + +_populate() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..734d30a76380d350e615da34929ec56d6d4bae7d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py @@ -0,0 +1,1821 @@ +""" +Variable tracking implementations for list-like data structures in Dynamo. + +This module provides specialized variable tracking for various collection types: +- Lists and list subclasses (including torch.nn.ModuleList, ParameterList) +- Tuples and named tuples +- Ranges and slices +- Collections.deque +- torch.Size with special proxy handling + +The implementations support both mutable and immutable collections, iteration, +and common sequence operations. Each collection type has a dedicated Variable +class that handles its unique behaviors while integrating with Dynamo's +variable tracking system. +""" + +import collections +import inspect +import operator +import sys +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.fx + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import ( + create_build_tuple, + create_call_function, + create_instruction, + create_rot_n, +) +from ..exc import raise_observed_exception, unimplemented +from ..source import AttrSource, NamedTupleFieldsSource +from ..utils import ( + cmp_name_to_op_mapping, + cmp_name_to_op_str_mapping, + get_fake_value, + guard_if_dyn, + iter_contains, + Lit, + namedtuple_fields, + odict_values, + raise_args_mismatch, + range_iterator, + set_example_value, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class BaseListVariable(VariableTracker): + @staticmethod + def cls_for_instance(obj: Any) -> type["BaseListVariable"]: + return BaseListVariable.cls_for(type(obj)) + + @staticmethod + def cls_for(obj: Any) -> type: + return { + iter: ListIteratorVariable, + list: ListVariable, + slice: SliceVariable, + torch.Size: SizeVariable, + tuple: TupleVariable, + odict_values: ListVariable, + torch.nn.ParameterList: ListVariable, + torch.nn.ModuleList: ListVariable, + collections.deque: DequeVariable, + }[obj] + + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + assert all(isinstance(x, VariableTracker) for x in items) + self.items: list[VariableTracker] = items + + def _as_proxy(self) -> list[Any]: + return [x.as_proxy() for x in self.items] + + def modified( + self, items: list[VariableTracker], **kwargs: Any + ) -> "BaseListVariable": + return type(self)(items, **kwargs) + + @property + def value(self) -> Any: + return self.as_python_constant() + + def debug_repr_helper(self, prefix: str, suffix: str) -> str: + return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix + + def as_python_constant(self) -> Any: + return self.python_type()([x.as_python_constant() for x in self.items]) + + def as_proxy(self) -> Any: + assert self.python_type() is not SizeVariable + return self.python_type()(self._as_proxy()) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + if index.step == 0: + msg = ConstantVariable.create("slice step cannot be zero") + raise_observed_exception(ValueError, tx, args=[msg]) + # Set source to None because slicing a list gives a new local + return self.clone( + items=self.items[index], + source=None, + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + else: + assert isinstance(index, (int, torch.SymInt)) + try: + return self.items[index] + except IndexError: + raise_observed_exception( + IndexError, tx, args=["list index out of range"] + ) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return list(self.items) + + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not isinstance(self, (ListVariable, TupleVariable)): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + + other_lists: list[BaseListVariable] = [] + for candidate in rest: + if ( + not isinstance(candidate, BaseListVariable) + or len(candidate.items) != len(self.items) + or self.python_type() != candidate.python_type() + ): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_lists.append(candidate) + + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [candidate.items[idx] for candidate in other_lists] + new_items.append( + item.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + ) + + return self.clone( + items=new_items, + source=None, + mutation_type=ValueMutationNew(), + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if args[0].is_tensor(): + value = get_fake_value(args[0].as_proxy().node, tx) + if value.constant is not None and value.constant.numel() == 1: + value = variables.ConstantVariable.create(value.constant.item()) + else: + unimplemented( + gb_type="Indexing list with non-scalar tensor", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=( + "Attempted to index list-like object with tensor with > 1 element." + ), + hints=[*graph_break_hints.USER_ERROR], + ) + else: + value = args[0] + + if value.python_type() not in (int, slice): + msg = f"indices must be integers or slices, not {value.python_type()}" + raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)]) + + return self.getitem_const(tx, value) + elif name == "__contains__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return iter_contains(self.unpack_var_sequence(tx), args[0], tx) + elif name == "index": + if not len(args): + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.index), + [self] + list(args), + kwargs, + ) + elif name == "count": + if len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return VariableTracker.build(tx, operator.countOf).call_function( + tx, + [self, args[0]], + kwargs, + ) + elif name in ("__add__", "__iadd__"): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if type(self) is not type(args[0]): + tp_name = self.python_type_name() + other = args[0].python_type_name() + msg_vt = ConstantVariable.create( + f'can only concatenate {tp_name} (not "{other}") to {tp_name}' + ) + raise_observed_exception(TypeError, tx, args=[msg_vt]) + + if name == "__add__": + return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined] + else: + self.items += args[0].items # type: ignore[attr-defined] + return self + elif name in ("__mul__", "__imul__"): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if not (args[0].is_python_constant() and args[0].python_type() is int): + msg_vt = ConstantVariable.create( + f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg_vt]) + + val = args[0].as_python_constant() + + if name == "__mul__": + return type(self)(self.items * val, source=self.source) + else: + self.items *= val + return self + elif name in cmp_name_to_op_mapping: + if len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + left = self + right = args[0] + # TODO this type check logic mirrors the following + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007 + # But we should probably move it up the stack to so that we don't + # need to duplicate it for different VTs. + if not isinstance(left, BaseListVariable) or not isinstance( + right, BaseListVariable + ): + if name == "__eq__": + return variables.BuiltinVariable(operator.is_).call_function( + tx, (left, right), {} + ) + elif name == "__ne__": + return variables.BuiltinVariable(operator.is_not).call_function( + tx, (left, right), {} + ) + else: + op_str = cmp_name_to_op_str_mapping[name] + left_ty = left.python_type_name() + right_ty = right.python_type_name() + msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'" + raise_observed_exception(TypeError, tx, args=[msg]) + + return variables.UserFunctionVariable(polyfills.list_cmp).call_function( + tx, + [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right], + {}, + ) + elif name == "__iter__": + return ListIteratorVariable(self.items, mutation_type=ValueMutationNew()) + + return super().call_method(tx, name, args, kwargs) + + +class RangeVariable(BaseListVariable): + def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: + items_to_map = items + start = variables.ConstantVariable.create(0) + stop = None + step = variables.ConstantVariable.create(1) + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + def maybe_as_int(x: VariableTracker) -> VariableTracker: + return ( + ConstantVariable.create(int(x.as_python_constant())) + if x.is_python_constant() + else x + ) + + # cast each argument to an integer + start = maybe_as_int(start) + step = maybe_as_int(step) + stop = maybe_as_int(stop) + + assert stop is not None + super().__init__([start, stop, step], **kwargs) + + def debug_repr(self) -> str: + return self.debug_repr_helper("range(", ")") + + def python_type(self) -> type: + return range + + def start(self) -> Any: + return self.items[0].as_python_constant() + + def stop(self) -> Any: + return self.items[1].as_python_constant() + + def step(self) -> Any: + return self.items[2].as_python_constant() + + def range_length(self) -> int: + lo = self.start() + hi = self.stop() + step = self.step() + + assert step != 0 + if step > 0 and lo < hi: + return 1 + (hi - 1 - lo) // step + elif step < 0 and lo > hi: + return 1 + (lo - 1 - hi) // (0 - step) + else: + return 0 + + def _get_slice_indices(self, length: int, slice: slice) -> list[int]: + step_is_negative = 0 + + if slice.step is None: + step = 1 + step_is_negative = False + else: + step = slice.step + step_is_negative = slice.step < 0 + + # Find lower and upper bounds for start and stop. + if step_is_negative: + lower = -1 + upper = length + lower + else: + lower = 0 + upper = length + + # Compute start + if slice.start is None: + start = upper if step_is_negative else lower + else: + start = slice.start + + if start < 0: + start += length + if start < lower: + start = lower + else: + if start > upper: + start = upper + + # Compute stop. + if slice.stop is None: + stop = lower if step_is_negative else upper + + else: + stop = slice.stop + + if stop < 0: + stop += length + if stop < lower: + stop = lower + else: + if stop > upper: + stop = upper + + return [start, stop, step] + + def apply_index(self, index: int) -> VariableTracker: + length = self.range_length() + if index < 0: + index = length + index + + if index < 0 or index >= length: + tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() + raise_observed_exception( + IndexError, + tx, + args=[ConstantVariable("range object index out of range")], + ) + + return variables.ConstantVariable.create(self.start() + (index * self.step())) + + def apply_slice(self, slice: slice) -> "RangeVariable": + (slice_start, slice_stop, slice_step) = self._get_slice_indices( + self.range_length(), slice + ) + + def compute_item(index: int) -> int: + return self.start() + (index * self.step()) + + sub_step = self.step() * slice_step + sub_start = compute_item(slice_start) + sub_stop = compute_item(slice_stop) + + result = RangeVariable( + [ + variables.ConstantVariable.create(x) + for x in [sub_start, sub_stop, sub_step] + ], + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + return result + + def as_python_constant(self) -> range: + return range(*[x.as_python_constant() for x in self.items]) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c + index = arg.as_python_constant() + + if isinstance(index, slice): + return self.apply_slice(index) + elif isinstance(index, int): + return self.apply_index(index) + else: + msg = ConstantVariable("range indices must be integers or slices") + raise_observed_exception(TypeError, tx, args=[msg]) + + def as_proxy(self) -> range: + return self.python_type()(*self._as_proxy()) + + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] = None + ) -> list[VariableTracker]: + return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert "range" not in codegen.tx.f_globals + codegen.add_push_null( + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] + ) + codegen.foreach(self.items) + codegen.extend_output(create_call_function(3, False)) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is range: + return variables.ConstantVariable.create(name in range.__dict__) + return super().call_obj_hasattr(tx, name) + + def range_equals(self, other: "RangeVariable") -> bool: + r0, r1 = self, other + if ( + self.range_length() != r1.range_length() + or self.range_length() == 0 + or r0.start() != r1.start() + ): + return False + + if self.range_length() == 1: + return True + + return r0.step() == r1.step() + + def range_count(self, x: VariableTracker) -> int: + # Based on CPython + # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486 + x = x.as_python_constant() + if type(x) not in (bool, int, float): + return 0 + + start, stop, step = self.start(), self.stop(), self.step() + + if step == 0: + return 0 + + in_range = (start <= x < stop) if step > 0 else (stop < x <= start) + + if in_range: + re = ((x - start) % step) == 0 + return int(re) + return 0 + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__iter__": + if not all(var.is_python_constant() for var in self.items): + # Can't represent a `range_iterator` without well defined bounds + return variables.misc.DelayGraphBreakVariable( + msg="Cannot create range_iterator: bounds (start, stop, step) must be fully defined as concrete constants.", + ) + return RangeIteratorVariable( + self.start(), self.stop(), self.step(), self.range_length() + ) + elif name == "__len__": + length = self.range_length() + if length > sys.maxsize: + raise_observed_exception(OverflowError, tx) + return ConstantVariable.create(self.range_length()) + elif name in ("count", "__contains__"): + return ConstantVariable(self.range_count(*args)) + elif name == "__getitem__": + return self.getitem_const(tx, *args) + elif name in cmp_name_to_op_mapping: + other = args[0] + pt = other.python_type() + if name not in ("__eq__", "__ne__"): + # ranges are only comparable to other ranges + msg = f"{name} not supported between instances of 'range' and '{pt}'" + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create(msg)], + ) + + if pt is not range: + return ConstantVariable.create(NotImplemented) + + if isinstance(other, RangeVariable): + cmp = self.range_equals(other) + else: + cmp = False + + # Two ranges are equal if they produce the same sequence of values + if name == "__eq__": + return ConstantVariable(cmp) + else: + return ConstantVariable(not cmp) + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + fields = ["start", "stop", "step"] + if name in fields: + return self.items[fields.index(name)] + return super().var_getattr(tx, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + l = self.range_length() + start = self.start() + step = self.step() + return hash((l, start, step)) + + def is_python_equal(self, other): + if not isinstance(other, variables.RangeVariable): + return False + + return ( + self.start() == other.start() + and self.step() == other.step() + and self.stop() == other.stop() + ) + + +class CommonListMethodsVariable(BaseListVariable): + """ + Implement methods common to List and other List-like things + """ + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "append" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + (arg,) = args + tx.output.side_effects.mutation(self) + self.items.append(arg) + return ConstantVariable.create(None) + elif name == "extend" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if not args[0].has_force_unpack_var_sequence(tx): + msg = ConstantVariable.create(f"{type(args[0])} object is not iterable") + raise_observed_exception(TypeError, tx, args=[msg]) + + (arg,) = args + arg.force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "append", [item], {}) + ) + return ConstantVariable.create(None) + elif name == "insert" and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + idx, value = args + if isinstance(idx, SymNodeVariable): + const_idx = idx.evaluate_expr() + else: + const_idx = idx.as_python_constant() + tx.output.side_effects.mutation(self) + self.items.insert(const_idx, value) + return ConstantVariable.create(None) + elif name == "pop" and self.is_mutable(): + if kwargs or len(args) > 1: + raise_args_mismatch( + tx, + name, + "at most 1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if len(self.items) == 0: + msg = ConstantVariable.create("pop from empty list") + raise_observed_exception(IndexError, tx, args=[msg]) + + if len(args): + idx = args[0].as_python_constant() + if idx > len(self.items): + msg = ConstantVariable.create("pop index out of range") + raise_observed_exception(IndexError, tx, args=[msg]) + tx.output.side_effects.mutation(self) + return self.items.pop(*[a.as_python_constant() for a in args]) + elif name == "clear" and self.is_mutable(): + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif ( + name == "__setitem__" + and self.is_mutable() + and args + and ( + args[0].is_python_constant() + or isinstance(args[0], SymNodeVariable) + or ( + isinstance(args[0], SliceVariable) + and all( + s.is_python_constant() or isinstance(s, SymNodeVariable) + for s in args[0].items + ) + ) + ) + ): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SymNodeVariable): + self.items[key.evaluate_expr()] = value + elif isinstance(key, SliceVariable): + if key.is_python_constant(): + self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined] + else: + items_slice = slice( + *[ + ( + s.evaluate_expr() + if isinstance(s, SymNodeVariable) + else s.as_python_constant() + ) + for s in key.items + ] + ) + self.items[items_slice] = list(value.items) # type: ignore[attr-defined] + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + tx.output.side_effects.mutation(self) + if args[0].is_python_constant() and isinstance( + args[0].as_python_constant(), (int, slice) + ): + if isinstance(args[0], SymNodeVariable): + idx = args[0].evaluate_expr() + else: + idx = args[0].as_python_constant() + + try: + self.items.__delitem__(idx) + except (IndexError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + else: + msg = ConstantVariable.create( + f"list indices must be integers or slices, not {args[0].python_type_name()}" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + return ConstantVariable.create(None) + elif name == "copy": + # List copy() doesn't have args and kwargs + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items_lst: list[VariableTracker] = list(self.items) + return self.modified(items_lst, mutation_type=ValueMutationNew()) + elif name == "reverse" and self.is_mutable(): + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.items.reverse() + tx.output.side_effects.mutation(self) + return ConstantVariable.create(None) + elif name == "remove" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + idx = self.call_method(tx, "index", args, kwargs) + self.call_method(tx, "pop", [idx], {}) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + + +class ListVariable(CommonListMethodsVariable): + def python_type(self) -> type: + return list + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self) -> str: + return self.debug_repr_helper("[", "]") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items))) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "__setitem__" and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + key, value = args + + if not key.is_python_constant(): + # probably will graph-break + super().call_method(tx, name, args, kwargs) + + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + if not value.has_force_unpack_var_sequence(tx): + msg = ConstantVariable.create("can only assign an iterable") + raise_observed_exception(TypeError, tx, args=[msg]) + + key_as_const = key.as_python_constant() + if key_as_const.step == 0: + msg = ConstantVariable.create("slice step cannot be zero") + raise_observed_exception(ValueError, tx, args=[msg]) + + value_unpack = value.force_unpack_var_sequence(tx) + try: + self.items[key_as_const] = value_unpack + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + else: + if isinstance(key, SymNodeVariable): + key = key.evaluate_expr() + else: + key = key.as_python_constant() + + try: + self.items[key] = value + except (IndexError, TypeError) as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + return ConstantVariable.create(None) + + if name == "sort" and self.is_mutable(): + if len(args) != 0: + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + key_fn_var = kwargs.pop("key", ConstantVariable.create(None)) + reverse = kwargs.pop( + "reverse", ConstantVariable.create(False) + ).as_python_constant() + if len(kwargs) != 0: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + + if key_fn_var.is_constant_none(): + keys = self.items.copy() + else: + keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items] + + if not all(k.is_python_constant() for k in keys): + first_non_constant_key = None + for k in keys: + if not k.is_python_constant(): + first_non_constant_key = k + assert first_non_constant_key is not None + + try: + python_type = str(first_non_constant_key.python_type()) + except NotImplementedError: + python_type = "unknown" + + unimplemented( + gb_type="sort with non-constant keys", + context=str(first_non_constant_key), + explanation=( + f"Cannot perform sort with non-constant key. " + f"First non-constant key type: {python_type}. " + f"Most notably, we cannot sort with Tensor or SymInt keys, but we can " + f"sort ints." + ), + hints=["Use something else as the key."], + ) + + tx.output.side_effects.mutation(self) + sorted_items_with_keys = sorted( + ( + ( + x, + k.as_python_constant(), + -i if reverse else i, # extra key to ensure stable sort + ) + for i, (k, x) in enumerate(zip(keys, self.items)) + ), + key=operator.itemgetter(1, 2), + reverse=reverse, + ) + self.items[:] = [x for x, *_ in sorted_items_with_keys] + return ConstantVariable.create(None) + + if name == "__init__" and self.is_mutable(): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + if len(args) == 0: + return ConstantVariable.create(None) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + (arg,) = args + tx.output.side_effects.mutation(self) + self.items[:] = arg.force_unpack_var_sequence(tx) + return ConstantVariable.create(None) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is list: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is not list: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr([], name)) + + def is_python_hashable(self): + return False + + +class DequeVariable(CommonListMethodsVariable): + def __init__( + self, + items: list[VariableTracker], + maxlen: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + if maxlen is None: + maxlen = ConstantVariable.create(None) + assert maxlen.is_python_constant(), ( + f"maxlen must be a constant, got: {maxlen.debug_repr()}" + ) + self.maxlen = maxlen + items = list(items) + if self.maxlen.as_python_constant() is not None: + items = items[-maxlen.as_python_constant() :] + super().__init__(items, **kwargs) + + def python_type(self) -> type: + return collections.deque + + def debug_repr(self) -> str: + if self.maxlen.as_python_constant() is None: + return self.debug_repr_helper( + "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")" + ) + return self.debug_repr_helper("deque([", "])") + + def as_python_constant(self) -> collections.deque[Any]: + return self.python_type()( + [x.as_python_constant() for x in self.items], + maxlen=self.maxlen.as_python_constant(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_python_module(collections.deque) # type: ignore[arg-type] + ) + ) + codegen.foreach(self.items) + codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))]) + codegen(self.maxlen) + codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False)) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "maxlen": + return self.maxlen + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + key, value = args + assert key.is_python_constant() + assert isinstance(key.as_python_constant(), int) + tx.output.side_effects.mutation(self) + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + + maxlen = self.maxlen.as_python_constant() + if maxlen is not None: + slice_within_maxlen = slice(-maxlen, None) + else: + slice_within_maxlen = None + + if ( + name == "extendleft" + and self.is_mutable() + and len(args) > 0 + and args[0].has_force_unpack_var_sequence(tx) + ): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # NOTE this is inefficient, but the alternative is to represent self.items + # as a deque, which is a more intrusive change. + args[0].force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "appendleft", [item], {}) + ) + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "popleft" and self.is_mutable(): + if kwargs or len(args) > 0: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + result, *self.items[:] = self.items + elif name == "appendleft" and len(args) > 0 and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items[:] = [args[0], *self.items] + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "insert" and len(args) > 0 and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if maxlen is not None and len(self.items) == maxlen: + raise_observed_exception( + IndexError, tx, args=["deque already at its maximum size"] + ) + result = super().call_method(tx, name, args, kwargs) + else: + result = super().call_method(tx, name, args, kwargs) + + if ( + slice_within_maxlen is not None + and maxlen is not None + and len(self.items) > maxlen + ): + self.items[:] = self.items[slice_within_maxlen] + return result + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is collections.deque: + return variables.ConstantVariable.create(name in collections.deque.__dict__) + return super().call_obj_hasattr(tx, name) + + +class TupleVariable(BaseListVariable): + def python_type(self) -> type[tuple]: # type: ignore[type-arg] + return tuple + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self) -> str: + return self.debug_repr_helper("(", ")") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_build_tuple(len(self.items))) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is tuple: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is not tuple: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr((), name)) + + def is_python_hashable(self): + return all(item.is_python_hashable() for item in self.items) + + def get_python_hash(self): + items = tuple(x.get_python_hash() for x in self.items) + return hash(items) + + def is_python_equal(self, other): + return isinstance(other, variables.TupleVariable) and all( + a.is_python_equal(b) for (a, b) in zip(self.items, other.items) + ) + + +class SizeVariable(TupleVariable): + """torch.Size(...)""" + + _nonvar_fields = { + "proxy", + *TupleVariable._nonvar_fields, + } + + def __init__( + self, + items: list[VariableTracker], + proxy: Optional[torch.fx.Proxy] = None, + **kwargs: Any, + ) -> None: + self.proxy = proxy + super().__init__(items, **kwargs) + + def debug_repr(self) -> str: + return self.debug_repr_helper("torch.Size([", "])") + + def python_type(self) -> type: + return torch.Size + + def as_proxy(self) -> Any: + if self.proxy is not None: + return self.proxy + + # torch.Size needs special handling. Normally, we pun a list-like + # container to directly contain Proxy/Node objects from FX, and FX + # knows to look inside containers (via map_aggregate). But torch.Size + # is weird; although it subclasses from tuple, it doesn't allow + # members which aren't int-like (rejecting Proxy and Node). This + # means we can't use the normal representation trick + # torch.Size([proxy0, proxy1]). I looked into seeing if I could + # relax torch.Size in PyTorch proper, but if torch.Size constructor + # sees a type that it doesn't recognize, it will try to call + # __index__() on it, so there is no BC way to actually change this + # behavior (though it occurs to me that I could have just added a + # YOLO no checking alternate constructor.) + # + # To work around this problem, I represent a torch.Size proxy as + # a straight up proxy, that would have been constructed by taking + # the constituent proxies as arguments. This trick can be generally + # used for any construct that we need a proxy for but we can't + # directly represent as an aggregate; I don't see very many examples + # of this in torchdynamo though! + + # Look for a proxy. If there are none, do the legacy behavior + tracer = None + proxies = self._as_proxy() + for proxy in proxies: + if isinstance(proxy, torch.fx.Proxy): + tracer = proxy.tracer + break + + if tracer is None: + return torch.Size(proxies) + + proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {}) + set_example_value( + proxy.node, + torch.Size( + [ + p.node.meta["example_value"] if not isinstance(p, int) else p + for p in proxies + ] + ), + ) + return proxy + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size")) + codegen.foreach(self.items) + build_torch_size = [ + create_build_tuple(len(self.items)), + ] + create_call_function(1, False) + codegen.extend_output(build_torch_size) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return list(self.items) + + def numel(self, tx: "InstructionTranslator") -> VariableTracker: + from .builtin import BuiltinVariable + from .tensor import SymNodeVariable + + const_result = 1 + sym_sizes = [] + + for v in self.items: + if v.is_python_constant(): + const_result *= v.as_python_constant() + else: + assert isinstance(v, SymNodeVariable), type(v) + # Delay proxy calls until we know it will be necessary + sym_sizes.append(v) + + result = ConstantVariable.create(const_result) + if sym_sizes and const_result == 1: + # Skip multiplying by 1 + result, *sym_sizes = sym_sizes + + if not sym_sizes or const_result == 0: + return result + + mul = BuiltinVariable(operator.mul) + for v in sym_sizes: + result = mul.call_function(tx, [result, v], {}) + return result + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + out = self.get_item_dyn(tx, args[0]) + return out + elif name == "numel": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return self.numel(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_item_dyn( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + return SizeVariable(self.items[index]) + else: + assert isinstance(index, (int, torch.SymInt)) + return self.items[index] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(torch.Size, name)) + + +class NamedTupleVariable(TupleVariable): + _nonvar_fields = { + "tuple_cls", + "dynamic_attributes", + *TupleVariable._nonvar_fields, + } + + def __init__( + self, + items: list[VariableTracker], + tuple_cls: type, + dynamic_attributes: Optional[dict[str, VariableTracker]] = None, + **kwargs: Any, + ) -> None: + super().__init__(items, **kwargs) + self.tuple_cls = tuple_cls + self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} + + def is_namedtuple(self) -> bool: + return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( + getattr(self.tuple_cls, "_make", None) + ) + + def is_structseq(self) -> bool: + return not self.is_namedtuple() + + def fields(self) -> tuple[str, ...]: + return namedtuple_fields(self.tuple_cls) + + def debug_repr(self) -> str: + if self.is_structseq(): + # StructSequenceType(iterable) + return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) + # NamedTupleType(*iterable) + return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) + + def python_type(self) -> type: + return self.tuple_cls + + def as_python_constant(self) -> Any: + if self.is_structseq(): + # StructSequenceType(iterable) + result = self.python_type()([x.as_python_constant() for x in self.items]) + else: + # NamedTupleType(*iterable) + result = self.python_type()(*[x.as_python_constant() for x in self.items]) + + # Apply dynamic attributes if any were set + if self.dynamic_attributes: + for attr_name, attr_value in self.dynamic_attributes.items(): + # Convert VariableTracker to Python constant if needed + if hasattr(attr_value, "as_python_constant"): + python_value = attr_value.as_python_constant() + else: + raise NotImplementedError( + "Can not convert dynamic attribute without python constant value to python constant." + ) + setattr(result, attr_name, python_value) + + return result + + def as_proxy(self) -> Any: + assert self.python_type() is not SizeVariable + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()(self._as_proxy()) + # NamedTupleType(*iterable) + return self.python_type()(*self._as_proxy()) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Always reconstruct the NamedTuple normally first + # Constructors: + # StructSequenceType(iterable) + # NamedTupleType(*iterable) + # NamedTupleType._make(iterable) + if self.is_structseq(): + create_fn = self.tuple_cls + else: + create_fn = self.tuple_cls._make # type: ignore[attr-defined] + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(create_fn) + ) + ) + codegen.foreach(self.items) + codegen.extend_output( + [ + create_build_tuple(len(self.items)), + ] + + create_call_function(1, False) + ) + + for name, value in self.dynamic_attributes.items(): + codegen.dup_top() + codegen(value) + codegen.extend_output(create_rot_n(2)) + codegen.store_attr(name) + + def _is_method_overridden(self, method_name: str) -> bool: + """Checks if a method is overridden in the NamedTuple subclass. + + Args: + method_name (str): The name of the method to check. + + Returns: + bool: True if the method is overridden in the subclass, False otherwise. + + Raises: + ValueError: If the NamedTuple class does not inherit from both Tuple and Object. + """ + if len(self.tuple_cls.__mro__) < 3: + raise ValueError("NamedTuple should inherit from Tuple and Object.") + if getattr(self.tuple_cls, method_name, None) == getattr( + self.tuple_cls.__mro__[-3], method_name, None + ): + return False + return True + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__setattr__": + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + attr, value = args + attr = attr.as_python_constant() + if ( + # structseq is immutable + self.is_structseq() + # namedtuple directly created by `collections.namedtuple` is immutable + or self.tuple_cls.__bases__ == (tuple,) + # fields are immutable + or attr in self.fields() + ): + raise_observed_exception(AttributeError, tx) + # Subclass of namedtuple type can have dynamic attributes + tx.output.side_effects.mutation(self) + if self.source: + tx.output.side_effects.store_attr(self, attr, value) + self.dynamic_attributes[attr] = value + return ConstantVariable.create(None) + elif name == "_replace": + # NamedTuple._replace should create a new instance with replaced fields + if args: + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + + # Get the field names for validation + fields = self.fields() + + # Start with current items (copy them) + new_items = list(self.items) + + # Replace fields specified in kwargs + for field_name, new_value in kwargs.items(): + if field_name not in fields: + raise_observed_exception( + ValueError, + tx, + args=[ + ConstantVariable.create( + f"Got unexpected field name: '{field_name}'" + ) + ], + ) + + # Replace the item at the field's index + field_index = fields.index(field_name) + new_items[field_index] = new_value + + return NamedTupleVariable(new_items, self.tuple_cls) + + return super().call_method(tx, name, args, kwargs) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + if isinstance(arg, SliceVariable): + # slicing a namedtuple produces a tuple + return TupleVariable( + self.items[arg.as_python_constant()], + source=None, + ) + return super().getitem_const(tx, arg) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def check_and_create_method() -> Optional[VariableTracker]: + method = inspect.getattr_static(self.tuple_cls, name, None) + if isinstance(method, classmethod): + # We need the unbounded cls method to avoid the inline __self__ + return UserMethodVariable( + method.__func__, + variables.UserDefinedClassVariable(self.tuple_cls), + ) + elif isinstance(method, staticmethod): + # pyrefly: ignore[bad-argument-type] + return UserFunctionVariable(method.__func__) + elif inspect.isfunction(method): + return UserMethodVariable(method, self) + else: + return None + + # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten. + if ( + name == "_replace" + and not self._is_method_overridden("_replace") + and not self._is_method_overridden("__getattr__") + ): + # Return a BuiltinVariable for the _replace method + # Get the actual _replace method from the tuple class + actual_replace_method = getattr(self.tuple_cls, "_replace", None) + if actual_replace_method: + from ..source import AttrSource + + source = AttrSource(self.source, name) if self.source else None + return variables.GetAttrVariable(self, name, source=source) + # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples) + return super().var_getattr(tx, name) + + if name == "_fields": + result_source = NamedTupleFieldsSource(self.source) if self.source else None + return VariableTracker.build(tx, self.fields(), source=result_source) + + if name in self.dynamic_attributes: + return self.dynamic_attributes[name] + + fields = self.fields() + if name not in fields: + method = check_and_create_method() + if not method: + return super().var_getattr(tx, name) + return method + return self.items[fields.index(name)] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create( + name in self.dynamic_attributes or hasattr(self.tuple_cls, name) + ) + + +class SliceVariable(VariableTracker): + def __init__( + self, + items: Sequence[VariableTracker], + tx: Optional["InstructionTranslator"] = None, + **kwargs: Any, + ) -> None: + items_to_map = items + start, stop, step = [variables.ConstantVariable.create(None)] * 3 + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + # Convert TensorVariable to SymIntVariable by calling .item() + # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level + if start.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + start = start.call_method(tx, "item", [], {}) + if stop.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + stop = stop.call_method(tx, "item", [], {}) + if step.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + step = step.call_method(tx, "item", [], {}) + + self.items = (start, stop, step) + + super().__init__(**kwargs) + + def debug_repr(self) -> str: + return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")" + + def as_proxy(self) -> slice: + return slice(*[x.as_proxy() for x in self.items]) + + def python_type(self) -> type: + return slice + + def as_python_constant(self) -> slice: + return slice(*[guard_if_dyn(x) for x in self.items]) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented( + gb_type="Unsupported attribute for slice() object", + context=f"var_getattr {self} {name}", + explanation=f"Expected attribute to be one of {','.join(fields)} " + f"but got {name}", + hints=[*graph_break_hints.USER_ERROR], + ) + return self.items[fields.index(name)] + + +class ListIteratorVariable(IteratorVariable): + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, items: list[VariableTracker], index: int = 0, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + # Removing this check as it slows things down too much + # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 + + # assert all(isinstance(x, VariableTracker) for x in items) + self.items = items + self.index = index + self.is_exhausted = False + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + old_index = self.index + if old_index >= len(self.items) or self.is_exhausted: + self.is_exhausted = True + raise_observed_exception(StopIteration, tx) + + tx.output.side_effects.mutation(self) + self.index += 1 + return self.items[old_index] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(iter([]), name)) + + def python_type(self) -> type: + return type(iter([])) + + def as_python_constant(self) -> Any: + if self.index > 0: + raise NotImplementedError + return iter([x.as_python_constant() for x in self.items]) + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + if self.is_exhausted: + return [] + self.is_exhausted = True + return list(self.items[self.index :]) + + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + return self.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen") -> None: + if not self.is_exhausted: + remaining_items = self.items[self.index :] + else: + remaining_items = [] + codegen.foreach(remaining_items) + codegen.extend_output( + [ + create_build_tuple(len(remaining_items)), + create_instruction("GET_ITER"), + ] + ) + + +class TupleIteratorVariable(ListIteratorVariable): + pass + + +class RangeIteratorVariable(IteratorVariable): + # only needed for isinstance(..., range_iterator) to work + _nonvar_fields = { + "iter_obj", + } + + def __init__( + self, start: int, stop: int, step: int, len_: int, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.start = start + self.stop = stop + self.step = step + self.len = len_ + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + return self + return super().call_method(tx, name, args, kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is range_iterator: + ri = iter(range(0)) + return ConstantVariable(hasattr(ri, name)) + return super().call_obj_hasattr(tx, name) + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + if self.len <= 0: + raise_observed_exception(StopIteration, tx) + + self.len -= 1 + current = self.start + self.start += self.step + return ConstantVariable.create(current) + + def python_type(self) -> type: + return range_iterator + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] + ) + codegen.append_output(codegen.create_load_const(self.start)) + codegen.append_output(codegen.create_load_const(self.stop)) + codegen.append_output(codegen.create_load_const(self.step)) + codegen.extend_output(create_call_function(3, False)) + codegen.append_output(create_instruction("GET_ITER")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..95816b81fa199d8427c24a9b50cbd74e24b81f24 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py @@ -0,0 +1,2129 @@ +# mypy: ignore-errors + +""" +This module contains miscellaneous variable tracker implementations for various Python types +and features used in Dynamo's symbolic execution. These classes help track and propagate +information about different kinds of variables during graph capture. + +Key classes include: +- SuperVariable: Handles super() calls and method resolution +- ExceptionVariable: Tracks exception objects +- RandomVariable: Manages random number generators +- GetAttrVariable: Tracks attribute access +- MethodWrapperVariable: Handles method wrappers +- PythonModuleVariable: Tracks Python modules +- NumpyVariable: Handles numpy functions and types +- StringFormatVariable: Manages string formatting +- DebuggingVariable: Handles print and logging +""" + +import dataclasses +import enum +import functools +import inspect +import itertools +import random +import re +import sys +import types +import warnings +from typing import Optional, TYPE_CHECKING + +import torch._C +import torch._numpy as tnp +import torch.utils._pytree as pytree + +from .. import config, graph_break_hints, trace_rules, variables +from ..bytecode_transformation import ( + create_call_function, + create_call_function_ex, + create_instruction, +) +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init +from ..source import ( + AttrSource, + GenericAttrSource, + GetItemSource, + TypeMROSource, + TypeSource, + WeakRefCallSource, +) +from ..utils import ( + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + identity, + is_tensor_base_attr_getter, + istype, + list_methods, + proxy_args_kwargs, + raise_args_mismatch, + tuple_methods, +) +from .base import ( + AsPythonConstantNotImplementedError, + raise_type_error_exc, + VariableTracker, +) +from .constant import ConstantVariable +from .functions import NestedUserFunctionVariable, UserFunctionVariable +from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class NO_SUCH_SUBOBJ: + pass + + +class SuperVariable(VariableTracker): + _nonvar_fields = { + *VariableTracker._nonvar_fields, + } + + def __init__(self, typevar, objvar=None, **kwargs) -> None: + super().__init__(**kwargs) + # typevar is the first argument to super(). In the case where no argument + # is provided to super(), it is the __class__ object where + # the super() function is being called + self.typevar = typevar + # objvar here must be an instance or subtype of typevar. + # In the case where super() is called without arguments, it is the first argument + # to the current function where super() is called from (self for regular method, + # cls for a classmethod) + self.objvar = objvar + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) + codegen(self.typevar) + if self.objvar is not None: + codegen(self.objvar) + codegen.extend_output(create_call_function(2, False)) + else: + codegen.extend_output(create_call_function(1, False)) + + def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): + if not self.objvar: + unimplemented( + gb_type="1-arg super not implemented", + context="", + explanation=f"Dynamo failed to trace attribute `{name}` accessed " + f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) " + "because one-argument of super() is not supported.", + hints=[ + "Use two-argument super(type, object_or_type).", + ], + ) + search_type = self.typevar.as_python_constant() + + # The rest of this function does two things: + # - Walk the mro to find where the attribute comes from to be + # able to provide accurate source + # - Call the getattr to get the object + + # Find the class object, where the function lives. + # When objvar is "self", use type(self), when objvar is "cls", use it as-is + type_to_use = self.objvar.python_type() + type_to_use_source = ( + TypeSource(self.objvar.source) if self.objvar.source else None + ) + if issubclass(type_to_use, type): + type_to_use = self.objvar.value + type_to_use_source = self.objvar.source + + source = None + search_mro = type_to_use.__mro__ + + try: + start_index = search_mro.index(search_type) + 1 + except ValueError: + # Corner case where the typevar is not in the mro of the objvar + # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 + return getattr(super(search_type, type_to_use), name), None + # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 + # super has its getattro implementation. The key point is that instead of calling getattr, it checks the + # attribute in the class __dict__ + for index in range(start_index, len(search_mro)): + # Dont call getattr, just check the __dict__ of the class + if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): + if resolved_getattr is not NO_SUCH_SUBOBJ: + # Equivalent of something like type(L['self']).__mro__[1].attr_name + if type_to_use_source: + source = AttrSource( + GetItemSource(TypeMROSource(type_to_use_source), index), + name, + ) + return resolved_getattr, source + + unimplemented( + gb_type="Unable to resolve super getattr", + context="", + explanation=f"Dynamo failed to trace attribute `{name}` accessed " + f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) " + "because the resolved attribute type is not supported.", + hints=[ + "Ensure the attribute exists in the parent class.", + "Check the arguments passed to `super()`.", + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # Check if getattr is a constant. If not, delay the actual work by + # wrapping the result in GetAttrVariable. Mostly super is called with a + # method, so most of the work is delayed to call_function. + # + # We could have just implemented a const_getattr. However, super is + # special when it comes to finding sources. Compared to other VTs, super + # requires the attr name to walk the mro and find the actual source (and + # not just AttrSource). + value, source = self._resolved_getattr_and_source(self, name) + if not variables.ConstantVariable.is_literal(value): + return GetAttrVariable(self, name) + if source: + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). + if inner_fn is object.__init__: + return LambdaVariable(identity) + elif inner_fn is torch.nn.Module.__init__: + objvar = self.objvar + from ..side_effects import AttributeMutationNew + + if ( + isinstance(objvar, variables.UserDefinedObjectVariable) + and isinstance(objvar.mutation_type, AttributeMutationNew) + and not (args or kwargs) + ): + with do_not_convert_to_tracable_parameter(): + fn_vt = VariableTracker.build( + tx, unpatched_nn_module_init, source=source + ) + return fn_vt.call_function(tx, [self.objvar] + args, kwargs) + else: + unimplemented( + gb_type="Unsupported super().__init__() call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered a super().__init__() call " + f"on {objvar} that resolved to a `torch.nn.Module.__init__()` " + "call that we cannot trace.", + hints=[*graph_break_hints.DIFFICULT], + ) + elif ( + self.objvar.source + and hasattr(inner_fn, "__name__") + and inner_fn.__name__ == "__new__" + and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn) + ): + user_cls = inner_fn.__self__ + if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins": + user_cls_vt = variables.BuiltinVariable(user_cls) + else: + user_cls_source = source.member + user_cls_vt = variables.UserDefinedClassVariable( + user_cls, source=user_cls_source + ) + return user_cls_vt.call_method(tx, "__new__", args, kwargs) + elif isinstance(inner_fn, staticmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source) + return fn_vt.call_function(tx, args, kwargs) + elif isinstance(inner_fn, classmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + if isinstance(self.objvar, variables.UserDefinedClassVariable): + # super().classmethod is called from a classmethod itself. So, + # super was converted to super(__class__, cls) in bytecode and + # therefore we have to propagate the cls. + cls_variable = self.objvar + else: + # current function is an instance method, therefore super was + # converted to super(__class__, self). We have to find + # type(self) to bind the cls to the parent classmethod. + # Note that it can't be the self.typevar because __class__ is + # the class where the method is defined, which could be + # different from type(self) with polymorphism. + cls_source = None + if self.objvar.source: + cls_source = TypeSource(self.objvar.source) + cls_variable = VariableTracker.build( + tx, self.objvar.value_type, cls_source + ) + + fn_vt = VariableTracker.build( + tx, inner_fn.__func__, source=AttrSource(source, "__func__") + ) + return fn_vt.call_function(tx, [cls_variable, *args], kwargs) + elif isinstance(inner_fn, types.FunctionType): + fn_vt = VariableTracker.build(tx, inner_fn, source=source) + return fn_vt.call_function(tx, [self.objvar] + args, kwargs) + elif isinstance(inner_fn, types.MethodType): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source + ).call_function(tx, args, kwargs) + elif is_standard_setattr(inner_fn) and isinstance( + self.objvar, UserDefinedObjectVariable + ): + return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError as exc: + unimplemented( + gb_type="Non-constant attribute given to `super().__delattr__()`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the attribute name passed to " + "`super().__delattr__(...)` to be a constant (string).", + hints=[ + "Ensure the attribute name is a string literal or a constant variable." + ], + from_exc=exc, + ) + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented( + gb_type="Attempted super().__delattr__() on an object without mutation tracking", + context=f"call_method {self} {name}", + explanation="Dynamo needs to track mutations on an object " + "before `super().__delattr__` can be used on it. But the " + f"object ({self.objvar}) doesn't have attribute mutation " + "tracking enabled.", + hints=[ + "Ensure the object is tracked by Dynamo's side effect system.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) + elif ( + isinstance(self.objvar, variables.UserDefinedDictVariable) + and inner_fn in self.objvar._dict_methods + ): + return self.objvar._dict_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedSetVariable) + and inner_fn in self.objvar._set_methods + ): + return self.objvar._set_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedTupleVariable) + and inner_fn in tuple_methods + ): + return self.objvar._tuple_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedListVariable) + and inner_fn in list_methods + ): + return self.objvar._list_vt.call_method(tx, name, args, kwargs) + elif inner_fn is object.__getattribute__: + # object.__getattribute__ has no side-effects. We can directly call + # __getattribute__ to access the attribute. + attr_name = args[0].value + if tx.output.side_effects.has_pending_mutation_of_attr( + self.objvar, attr_name + ): + result = tx.output.side_effects.load_attr( + self.objvar, attr_name, deleted_ok=True + ) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception(AttributeError, tx) + return result + + try: + # NB - use object.__getattribute__ to prevent running any user code + attr_value = object.__getattribute__(self.objvar.value, attr_name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + + attr_source = None + if self.objvar.source is not None: + # setup a object.__getattribute__(self.objvar, name) source + attr_source = GenericAttrSource(self.objvar.source, attr_name) + return VariableTracker.build(tx, attr_value, attr_source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. + # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) + + unimplemented( + gb_type="Attempted to call a super() attribute that is " + "not a function or method", + context=f"call_method {self} {name}", + explanation="Dynamo does not know how to trace the call " + f"`super().{name}()` because `super().{name}` is not a " + "function or method attribute.", + hints=[ + "Ensure the attribute accessed via `super()` is a standard method or function.", + ], + ) + + +class ExceptionVariable(VariableTracker): + # The ExceptionVariable corresponds to the BaseException class in Python + def __init__( + self, exc_type, args, init_kwargs=None, source=None, mutation_type=None + ) -> None: + super().__init__(source=source, mutation_type=mutation_type) + self.exc_type = exc_type + self.args = args + if init_kwargs: + unimplemented( + gb_type="Keyword args passed to exception constructor", + context=f"{self} with kwargs {init_kwargs}", + explanation="Dynamo does not know how to handle keyword args passed to an exception constructor", + hints=[*graph_break_hints.SUPPORTABLE], + ) + # When raising a new exception while another exception is already being + # handled, the new exception's __context__ attribute is automatically + # set to the handled exception. + self.__context__ = ConstantVariable(None) + # Set when user raised an exception from another: + # raise ... from ... + self.__cause__ = ConstantVariable(None) + # Boolean flag that controls whether the __context__ attribute is set + self.__suppress_context__ = ConstantVariable(False) + # Contains the call stack where the exception was raised. Dynamo does + # not track traceback. So, this variable is always set to None + self.__traceback__ = ConstantVariable(None) + + def set_context(self, context: "ExceptionVariable"): + self.__context__ = context + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", self.exc_type.__name__) + ) + codegen.foreach(self.args) + codegen.call_function(len(self.args), False) + + def codegen_attr(name: str) -> None: + attr = getattr(self, name) + if istype(attr, ConstantVariable): + assert attr.value in (True, False, None), attr + else: + codegen.dup_top() + codegen(attr) + codegen.extend_output(codegen.rot_n(2)) + codegen.store_attr(name) + + codegen_attr("__context__") + codegen_attr("__cause__") + codegen_attr("__suppress_context__") + + def python_type(self): + return self.exc_type + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ): + def raise_error(msg): + raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)]) + + name = name_var.as_python_constant() + if name == "__context__": + self.set_context(val) + elif name == "__cause__": + if val.is_constant_none() or isinstance( + val, + ( + variables.BuiltinVariable, + variables.ExceptionVariable, + variables.UserDefinedExceptionClassVariable, + variables.UserDefinedExceptionObjectVariable, + ), + ): + self.__cause__ = val + self.__suppress_context__ = variables.ConstantVariable(True) + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__suppress_context__": + if val.is_constant_match(True, False): + self.__suppress_context__ = val + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__traceback__": + if val.is_constant_none(): + self.__traceback__ = val + else: + unimplemented( + gb_type="Set Exception object `__traceback__` attribute to not-`None`", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + "'__traceback__' on tracked exception objects to anything " + "other than None.", + hints=[ + "Avoid setting '__traceback__' on exception objects " + "within traced code, or set it to None." + ], + ) + else: + unimplemented( + gb_type="Unsupported attribute assignment on Exception object", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + f"'{name}' on tracked exception objects. Only `__context__`, " + "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + return variables.ConstantVariable(None) + + def call_method(self, tx, name, args, kwargs): + if name == "__setattr__": + return self.call_setattr(tx, *args) + elif name == "with_traceback": + [tb] = args + self.call_setattr(tx, ConstantVariable("__traceback__"), tb) + return self + else: + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "__context__": + return self.__context__ + elif name == "__cause__": + return self.__cause__ + elif name == "__suppress_context__": + return self.__suppress_context__ + elif name == "__traceback__": + return variables.ConstantVariable(None) + elif name == "args": + return variables.ListVariable(self.args, source=self.source) + return super().var_getattr(tx, name) + + def __str__(self): + return f"{self.__class__.__name__}({self.exc_type})" + + __repr__ = __str__ + + +class UnknownVariable(VariableTracker): + """ + It could be anything! + """ + + +class DelayGraphBreakVariable(UnknownVariable): + """ + Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. + """ + + def __init__(self, msg=None, **kwargs): + super().__init__(**kwargs) + self.msg = msg + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented( + gb_type="Unsupported function call (delayed)", + context=f"source: {self.source}", + explanation="Dynamo determined that a graph break should occur " + f"when calling `{self.source.name}`. Reason: {self.msg}", + hints=[], + ) + + +class ComptimeVariable(VariableTracker): + """ + This variable is special, it lets you execute arbitrary code at + Dynamo compile time + """ + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError("comptime is special form") + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from ..comptime import comptime + + # To support the comptime.print_graph convenience accessors + return VariableTracker.build( + tx, getattr(comptime, name), source=AttrSource(self.source, name) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..comptime import ComptimeContext + + # TODO: support an expression form as well + # Second argument is runtime lambda, ignored + if kwargs or len(args) > 2: + raise_args_mismatch( + tx, + "comptime()", + "at most 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + fn = args[0] + if isinstance(fn, UserFunctionVariable): + fn.get_function()(ComptimeContext(tx)) + elif isinstance(fn, NestedUserFunctionVariable): + # We have to manually bind the freevars ourselves + code = fn.get_code() + if fn.closure: + raise_type_error_exc( + tx, + f"comptime function must not have free variables, but these variables were free: {code.co_freevars}", + ) + func = types.FunctionType( + code, + fn.f_globals, + fn.fn_name.as_python_constant(), + tuple(fn.defaults.items) if fn.defaults else None, + # We could automatically promote free variables into + # ComptimeVar but this is confusing if you access + # a free variable that we actually DO have the runtime + # value for + # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items) + (), + ) + func(ComptimeContext(tx)) + else: + raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") + + return variables.ConstantVariable.create(None) + + +class CellVariable(VariableTracker): + # If the cell existed before Dynamo tracing started, this will be the + # VariableTracker that represents the cell content. + # + # Note that all mutation to the cell (i.e., its content) will be buffered in + # SideEffects, rather than being reflected here. One can think of + # `CellVariable` as a special case for `UserDefinedObjectVariable`. + pre_existing_contents: Optional[VariableTracker] + + # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the + # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`). + local_name: Optional[str] = None + + def __init__( + self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.pre_existing_contents = pre_existing_contents + + +class NewGlobalVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + +def produce_trampoline_autograd_apply(fn_cls): + def trampoline_autograd_apply(*args, **kwargs): + return fn_cls.apply(*args, **kwargs) + + trampoline_autograd_apply._origin = produce_trampoline_autograd_apply + return trampoline_autograd_apply + + +class AutogradFunctionVariable(VariableTracker): + """represents a torch.autograd.Function subclass""" + + _nonvar_fields = { + "fn_cls", + *VariableTracker._nonvar_fields, + } + + def __init__(self, fn_cls, **kwargs) -> None: + super().__init__(**kwargs) + self.fn_cls = fn_cls + + def call_apply(self, tx: "InstructionTranslator", args, kwargs): + requires_grad = False + + def visit(vt): + nonlocal requires_grad + if vt.is_tensor(): + if vt.requires_grad is not False: + requires_grad = True + if isinstance(vt, variables.NNModuleVariable): + if vt.is_training(tx): + requires_grad = True + + VariableTracker.visit(visit, (args, kwargs)) + + if requires_grad and torch.is_grad_enabled(): + if config.capture_autograd_function is False: + warnings.warn( + "The config.capture_autograd_function flag is deprecated, it's now always true." + ) + + from torch._functorch.autograd_function import ( + autograd_function_forward_rewritten, + ) + from torch.autograd.function import _is_setup_context_defined + + forward_fn = self.fn_cls.forward + + is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context) + if is_setup_ctx_defined: + # If setup_context is defined, we generate a new forward function which includes + # the original forward and setup_context function, and trace the new forward function. + forward_fn = autograd_function_forward_rewritten( + self.fn_cls.forward, self.fn_cls.setup_context + ) + + vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] + if vjp_fn is not torch.autograd.Function.vjp: + unimplemented( + gb_type="Unsupported custom vjp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `vjp` method.", + hints=[ + "Remove the custom `vjp` method if possible.", + "Use standard `backward` instead if applicable.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] + if jvp_fn is not torch.autograd.Function.jvp: + unimplemented( + gb_type="Unsupported custom jvp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `jvp` method.", + hints=[ + "Remove the custom `jvp` method if possible.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + from .higher_order_ops import AutogradFunctionApplyVariable + + source = self.source + if source is None: + source = AttrSource( + tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__ + ) + + val = AutogradFunctionApplyVariable( + forward_fn, + self.fn_cls.backward, + source, + source=AttrSource(source, member="apply"), + ).call_function(tx, args, kwargs) + # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping + # the forward function, as we don't want to generate guards for new_forward.__closure__ + # if forward is rewritten by autograd_function_forward_rewritten. + # But we still need to generate correct guards for the original forward and setup_context + # functions, so we have to add guards manually. + if self.source: + fwd_src = AttrSource(self.source, "forward") + install_guard(fwd_src.make_guard(GuardBuilder.CLOSURE_MATCH)) + if is_setup_ctx_defined: + setup_ctx_src = AttrSource(self.source, "setup_context") + install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH)) + + return val + + if self.source: + source = AttrSource(self.source, "forward") + else: + source = None + + fn = self.fn_cls.forward + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + args = [ctx, *args] + if isinstance(fn, types.FunctionType): + sig = inspect.signature(fn) + if len(args) - 1 == len(sig._parameters): + args = args[1:] # Don't use context + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, args, kwargs) + elif isinstance(fn, types.MethodType): + return variables.UserMethodVariable( + fn.__func__, + variables.UserDefinedClassVariable(self.fn_cls), + source=source, + ).call_function(tx, args, kwargs) + else: + unimplemented( + gb_type="Non-function or method in subclass of torch.autograd.Function", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo requires the `forward` attribute of a " + "`torch.autograd.Function` subclass to be a standard Python " + f"function or method. Found type `{type(fn).__name__}` instead.", + hints=[ + "Ensure the `forward` method is defined as a regular " + "function or instance method." + ], + ) + + def call_backward(self, tx: "InstructionTranslator", args, kwargs): + fn = self.fn_cls.backward + assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction + assert isinstance(fn, types.FunctionType) + + fn_source = AttrSource(self.source, "backward") + fn_vt = VariableTracker.build(tx, fn, source=fn_source) + return fn_vt.call_function(tx, args, kwargs) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + return AutogradFunctionVariable(self.fn_cls) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + from .builder import wrap_fx_proxy + + if name == "apply": + if trace_rules.is_callable_allowed(self.fn_cls): + trampoline_autograd_apply = produce_trampoline_autograd_apply( + self.fn_cls + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + trampoline_autograd_apply, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + return self.call_apply(tx, args, kwargs) + + elif name == "backward": + return self.call_backward(tx, args, kwargs) + else: + source = AttrSource(self.source, name) if self.source is not None else None + try: + obj = inspect.getattr_static(self.fn_cls, name) + except AttributeError: + obj = None + + if isinstance(obj, staticmethod): + func = obj.__get__(self.fn_cls) + if source is not None: + return ( + trace_rules.lookup(func) + .create_with_source(func, source=source) + .call_function(tx, args, kwargs) + ) + else: + return trace_rules.lookup(func)(func).call_function( + tx, args, kwargs + ) + elif isinstance(obj, classmethod): + return variables.UserMethodVariable( + obj.__func__, self, source=source + ).call_function(tx, args, kwargs) + else: + unimplemented( + gb_type="Unsupported autograd.Function method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` directly on the `torch.autograd.Function` " + "instance. Supported methods include `apply`, `backward`, " + "static methods, and class methods.", + hints=[ + "Ensure the method is decorated with `@staticmethod` " + "or `@classmethod` if it's meant to be called on the class.", + ], + ) + + +@dataclasses.dataclass +class SavedTensorBox: + tensors: list[VariableTracker] = dataclasses.field(default_factory=list) + + +class AutogradFunctionContextVariable(UserDefinedObjectVariable): + """ + Tracks an autograd.Function() context using mutation tracking in side_effects.py + """ + + _nonvar_fields = { + "proxy", + "inference", + "saved_tensors", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value, + value_type=None, + inference=False, + saved_tensors=None, + needs_input_grad=None, + non_differentiable=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + self.inference = inference + self.saved_tensors = saved_tensors + self.needs_input_grad = needs_input_grad + self.non_differentiable = non_differentiable + + @staticmethod + def create(tx: "InstructionTranslator", args=None, kwargs=None): + needs_input_grad = None + if args and not kwargs: + needs_input_grad = tuple(x.is_tensor() and x.requires_grad for x in args) + out = tx.output.side_effects.track_object_new( + None, + torch.autograd.function.FunctionCtx, + functools.partial( + AutogradFunctionContextVariable, + inference=True, + saved_tensors=SavedTensorBox(), + needs_input_grad=needs_input_grad, + ), + {}, + ) + return out + + def as_proxy(self): + if self.proxy is None: + unimplemented( + gb_type="proxy not set", + context=f"as_proxy {self}", + explanation="Dynamo requires the autograd.Function context " + "to be initialized with a proxy.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + return self.proxy + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__setattr__": + return super().call_method(tx, name, args, kwargs) + elif name == "mark_non_differentiable": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + self.non_differentiable = proxy_args_kwargs(args, {})[0] + return variables.ConstantVariable.create(None) + + if name != "save_for_backward": + unimplemented( + gb_type="Unsupported autograd.Function context method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` on `autograd.Function` context objects. Supported " + "methods are `__setattr__`, `save_for_backward` and " + "`mark_non_differentiable`.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if self.saved_tensors is None: + unimplemented( + gb_type="Unsupported autograd.Function context `save_for_backward`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the `saved_tensors` attribute " + "to be initialized on the `autograd.Function` context object.", + hints=[ + "Ensure that the `saved_tensors` attribute is properly " + "initialized before calling `save_for_backward`. " + "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.", + ], + ) + + if not self.inference: + if kwargs or not self.source: + raise_type_error_exc( + tx, "save_for_backward() requires a source and no keyword arguments" + ) + tx.output.side_effects.track_save_for_backward(self, args) + + # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls. + if len(self.saved_tensors.tensors) > 0: + self.saved_tensors.tensors = [] + for arg in args: + self.saved_tensors.tensors.append(arg) + return variables.ConstantVariable.create(None) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name in ["save_for_backward", "mark_non_differentiable"]: + return LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + if name == "saved_tensors" and self.saved_tensors is not None: + return variables.TupleVariable(list(self.saved_tensors.tensors)) + if name == "needs_input_grad": + if self.needs_input_grad is not None: + return variables.ConstantVariable.create(self.needs_input_grad) + if self.source: + source = AttrSource(self.source, "needs_input_grad") + return VariableTracker.build(tx, self.value.needs_input_grad, source) + + return super().var_getattr(tx, name) + + +class AutogradEngineVariable(UserDefinedObjectVariable): + """ + Represents a torch._C._ImperativeEngine instance. + """ + + def __init__( + self, + value, + value_type=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "queue_callback": + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + assert tx.one_graph or tx.error_on_graph_break, ( + "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + ) + # queue_callback is a method-wrapper, no need to insert a guard. + fn_vt = VariableTracker.build( + tx, + torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, + ) + return fn_vt.call_function( + tx, + (tx.output.side_effects.get_ca_final_callbacks_var(), *args), + kwargs, + ) + else: + unimplemented( + gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()", + context=f"call_method {self} {name}", + explanation="queue_callback() is only supported when " + "Compiled Autograd is enabled with fullgraph=True.", + hints=[], + ) + else: + unimplemented( + gb_type="Unsupported torch._C._ImperativeEngine method", + context=f"call_method {self} {name}", + explanation="Dynamo only supports the `queue_callback` method " + f"on a torch._C._ImperativeEngine instance, but found: `{name}`.", + hints=[], + ) + + +class LambdaVariable(VariableTracker): + def __init__(self, fn, **kwargs) -> None: + super().__init__(**kwargs) + self.fn = fn + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.fn(*args, **kwargs) + + +class GetAttrVariable(VariableTracker): + _nonvar_fields = { + "name", + "py_type", + *VariableTracker._nonvar_fields, + } + + def __init__(self, obj, name, py_type=None, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(obj, VariableTracker) + assert isinstance(name, str) + self.obj = obj + self.name = name + self.py_type = py_type # In some cases we know the type (ex. tensor methods) + + def python_type(self): + if self.py_type is not None: + return self.py_type + else: + return super().python_type() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.obj}, {self.name})" + + @staticmethod + def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): + return getattr(base_proxy, attr) + + def as_proxy(self): + return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) + + def as_python_constant(self): + constant = self.obj.as_python_constant() + try: + return getattr(constant, self.name) + except AttributeError: + raise NotImplementedError(f"{self} is not a constant") from None + + def const_getattr(self, tx: "InstructionTranslator", name): + if not isinstance(self.obj, variables.NNModuleVariable): + raise NotImplementedError + step1 = tx.output.get_submodule(self.obj.module_key) + if self.name not in step1.__dict__: + raise NotImplementedError + step2 = inspect.getattr_static(step1, self.name) + if name not in step2.__dict__: + raise NotImplementedError + return inspect.getattr_static(step2, name) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.obj) + codegen.extend_output(codegen.create_load_attrs(self.name)) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.obj.call_method(tx, self.name, args, kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if ( + name in ("__getitem__", "get") + and self.name == "__dict__" + and not kwargs + and args[0].is_python_constant() + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + # redirect to var_getattr on the original obj + return obj.var_getattr(tx, key) + + # Return the default value for get + if name == "get": + if len(args) == 2: + return args[1] + else: + return variables.ConstantVariable(None) + + elif ( + name == "__contains__" + and self.name == "__dict__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + return variables.ConstantVariable(True) + else: + return variables.ConstantVariable(False) + + elif name == "__setitem__" and self.name == "__dict__" and not kwargs: + if isinstance(self.obj, variables.UserDefinedObjectVariable): + # Bypass any custom setattr as we are updating the `__dict__` itself + return self.obj.method_setattr_standard( + tx, args[0], args[1], directly_update_dict=True + ) + if isinstance(self.obj, variables.NNModuleVariable): + # This matches how `setattr` is handled for NNModuleVariable + self.obj.convert_to_unspecialized(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_forwarded_dict(self, tx): + assert ( + self.name == "__dict__" + and isinstance(self.obj, variables.UserDefinedClassVariable) + and not tx.output.side_effects.has_pending_mutation(self.obj) + ) + self.obj.ban_mutation = True + return VariableTracker.build(tx, self.obj.value.__dict__, self.source) + + +class MethodWrapperVariable(VariableTracker): + def __init__(self, method_wrapper, **kwargs) -> None: + super().__init__(**kwargs) + self.method_wrapper = method_wrapper + self._builtin_fns = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( + args[0], variables.TensorVariable + ): + if not (len(args) == 1 and len(kwargs) == 0): + raise_type_error_exc( + tx, "tensor attribute getter takes exactly one argument" + ) + + return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) + + # method-wrapper variables are common in __init__ calls. For example, + # str("foo").__init__ is a method-wrapper. These method wrappers point + # to C functions. Here we intercept if these method-wrappers are from + # builtins and then call the function counterpart directly by obtaining + # the self object. + self_obj = self.method_wrapper.__self__ + wrapper_name = self.method_wrapper.__name__ + # TODO(dynamo-team) - We can perhaps expand the scope to more names and + # more builtins. + if wrapper_name == "__init__": + fn_obj = type(self_obj).__init__ + if fn_obj is object.__init__: + return variables.BuiltinVariable(object).call_method( + tx, wrapper_name, [self_obj, *args], kwargs + ) + elif ( + sys.version_info >= (3, 14) + # for some reason, even if the below check passes, + # self.method_wrapper may not be the same as type.__dict__["__annotations__"].__get__ + and self_obj is type.__dict__["__annotations__"] + and wrapper_name == "__get__" + ): + from .builder import SourcelessBuilder + + if len(args) == 1 and not kwargs: + try: + return SourcelessBuilder.create( + tx, self.method_wrapper(args[0].as_python_constant()) + ) + except AttributeError: + raise_observed_exception(AttributeError, tx) + except AsPythonConstantNotImplementedError: + pass + + unimplemented( + gb_type="unsupported type.__dict__['__annotations__'].__get__ call", + context=f"call_function {self}, args: {args}, kwargs: {kwargs}", + explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ " + "on a single constant argument (i.e. a type).", + hints=[ + "Make sure your call to type.__dict__['__annotations__'] only has " + "one positional argument (no keyword arguments).", + "Make sure the argument to type.__dict__['__annotations__'] is a constant " + "(i.e. type). For example, `object`, `int`, `MyCustomClass`.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return super().call_function(tx, args, kwargs) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.method_wrapper + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class GetSetDescriptorVariable(VariableTracker): + def __init__(self, desc, **kwargs) -> None: + super().__init__(**kwargs) + self.desc = desc + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "__get__" and self.source: + source = AttrSource(self.source, "__get__") + return VariableTracker.build(tx, self.desc.__get__, source) + else: + return super().var_getattr(tx, name) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.desc + + +class PythonModuleVariable(VariableTracker): + _nonvar_fields = { + "value", + "is_torch", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value: types.ModuleType, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") + + def python_type(self): + return types.ModuleType + + def as_python_constant(self): + return self.value + + def __repr__(self) -> str: + return f"PythonModuleVariable({self.value})" + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def var_getattr(self, tx: "InstructionTranslator", name): + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.output.side_effects.load_attr(self, name) + + if self.is_torch or name not in self.value.__dict__: + try: + attr_value = getattr(self.value, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + else: + attr_value = self.value.__dict__[name] + + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) + + +class TypingVariable(VariableTracker): + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # Create a new typing variable, e.g., `List[int]` + if name == "__getitem__" and len(args) == 1: + new_typing = self.value[args[0].as_python_constant()] + return TypingVariable(new_typing) + unimplemented( + gb_type="unsupported method call on `typing` variable", + context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}", + explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.", + hints=[ + f"Avoid calling the {name} method on {self.value}.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str): + from .builder import SourcelessBuilder, VariableBuilder + + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.side_effects.load_attr(self, name) + + value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(value) + else: + return SourcelessBuilder.create(tx, value) + + def as_python_constant(self): + return self.value + + def reconstruct(self, codegen: "PyCodegen") -> None: + if not isinstance(self.value, types.GenericAlias): + return super().reconstruct(codegen) + # We're just trying to load the type here. Reconstructing the type from + # scratch is tricky - for a type like `typing.List[int]` we'd need to + # deconstruct the origin and args. The origin for `List[int]` is `list` + # and the args is `(int,)`. When we recombine those we get the parts + # back and need to emit code for: + # + # `typing.List[int]` + # + # But it's # worse than that - what if `typing` isn't in the globals (or + # was loaded like `import typing as _typing ; _typing.List[int]`?) so we + # really need to do something like: + # + # `sys.modules["typing"].List[int]` + # + # Argh - but what if they rewrote the global `int`? So we have to do: + # + # `sys.modules["typing"].List[sys.modules["builtins"].int]` + # + # But where do we get `sys`? What if they never imported it or have + # something ELSE called `sys`? + # + # Let's skip all that noise and just emit it as a simple const. + # + codegen.append_output(codegen.create_load_const(self.value)) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +@functools.lru_cache(maxsize=1) +def get_np_to_tnp_map(): + """ + This generates a mapping from numpy modules to their torch._numpy + modules equivalents. + """ + from ..utils import NP_TO_TNP_MODULE + + np_fn_to_tnp_fn = {} + + for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): + for fn_name, tnp_fn in tnp_mod.__dict__.items(): + if callable(tnp_fn): + # some internal details do leak from tnp + # which are not part of numpy API. + if np_fn := getattr(np_mod, fn_name, None): + np_fn_to_tnp_fn[np_fn] = tnp_fn + + return np_fn_to_tnp_fn + + +@functools.lru_cache(maxsize=1) +def get_tnp_to_np_map(): + """ + This is just the reverse mapping of get_np_to_tnp_map() - mapping from + torch._numpy modules to numpy equivalents. + """ + m = get_np_to_tnp_map() + return {v: k for k, v in m.items()} + + +class NumpyVariable(VariableTracker): + """ + Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. + """ + + constant_fold_functions = (tnp.issubdtype,) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def can_constant_fold_through(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return fn in cls.constant_fold_functions + + @classmethod + def get_constant_collection_for_func(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return np_constant_collections_map.get(fn) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not config.trace_numpy: + unimplemented( + gb_type="attempted to trace numpy function with config.trace_numpy=False", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation=f"Attempted to trace numpy function {self.value} " + "while `torch._dynamo.config.trace_numpy` was set to False.", + hints=[ + "Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.", + ], + ) + + from ..utils import numpy_to_tensor_wrapper + from .tensor import NumpyNdarrayVariable + + func = get_np_to_tnp_map().get(self.value) + if func is None: + unimplemented( + gb_type="attempted to trace numpy function unsupported by PyTorch", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) + if ( + collection_variable_typ := self.get_constant_collection_for_func(func) + ) is not None: + try: + return collection_variable_typ( + self.value( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + except AsPythonConstantNotImplementedError: + unimplemented( + gb_type="numpy function that produces a const collection type encountered non-const arguments", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"numpy function {self.value} that produces a const collection type " + "(e.g. np.dtype, np.iinfo/np.finfo) " + "received arguments that are not constant.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + else: + if ( + func.__module__ == "torch._numpy.random" + and config.use_numpy_random_stream + ): + unimplemented( + gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` " + "is set to True.", + hints=[ + "Set `torch._dynamo.config.use_numpy_random_stream` to False.", + f"Avoid calling {self.value}.", + ], + ) + + args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) + + if self.can_constant_fold_through(func) and ( + check_unspec_or_constant_args(args, kwargs) + ): + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + # TODO Add all the functions that go from constants to constants to can_constant_fold_through + proxy = tx.output.create_proxy( + "call_function", + numpy_to_tensor_wrapper(func), + *proxy_args_kwargs(args, kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented( + gb_type="attempted to trace numpy.* function as a method", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="Tracing numpy.* functions as methods is not supported.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + + def as_python_constant(self): + return self.value + + def as_proxy(self): + if config.trace_numpy: + # Can replace with EnumType once we drop 3.10 support + if isinstance(self.value, enum.EnumMeta): + # This is mostly for np._CopyMode + return self.value + if isinstance(self.value, type): + # This handles numpy dtype attributes such as np.float32 + # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph + # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does + return self.value.__name__ + + return super().as_proxy() + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls +class NullVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "NullVariable" + + def reconstruct(self, codegen: "PyCodegen"): + if sys.version_info < (3, 11): + unimplemented( + gb_type="cannot reconstruct NullVariable in Python < 3.11", + context="", + explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; " + "where this instruction does not exist.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + codegen.append_output(create_instruction("PUSH_NULL")) + + +class DeletedVariable(VariableTracker): + """Marker used to implement delattr()""" + + +class StringFormatVariable(VariableTracker): + """ + Represents a call to str.format(), we delay calling format until after the graph. + """ + + _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} + + @classmethod + def create(cls, format_string, sym_args, sym_kwargs): + if all( + x.is_python_constant() + for x in itertools.chain(sym_args, sym_kwargs.values()) + ): + return variables.ConstantVariable.create( + format_string.format( + *[v.as_python_constant() for v in sym_args], + **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, + ) + ) + return cls(format_string, list(sym_args), dict(sym_kwargs)) + + def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(format_string, str) + self.format_string = format_string + self.sym_args = sym_args + self.sym_kwargs = sym_kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_const(self.format_string), + codegen.create_load_attr("format"), + ] + ), + call_function_ex=True, + ) + codegen(variables.TupleVariable(self.sym_args)) + kwargs = { + variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() + } + codegen(variables.ConstDictVariable(kwargs)) + codegen.extend_output(create_call_function_ex(True, False)) + + +class DebuggingVariable(VariableTracker): + """ + Represents a call to a debugging function like print(), or something + registered to config.reorderable_logging_functions. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @staticmethod + def is_reorderable_logging_function(obj): + return ( + callable(obj) + and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) + and obj in torch._dynamo.config.reorderable_logging_functions + ) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + + if not self.can_reorder_logs(self.value, args, kwargs): + unimplemented( + gb_type="attempted to reorder a debugging function that can't actually be reordered", + context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="`torch.compile` can only reorder functions where the arguments " + "are Tensors, constants, or string formatters.", + hints=[ + f"Avoid calling the logging function {self.value} with args that are not supported.", + ], + ) + + tx.debug_locals.append((self, list(args))) + + def reconstruct(self, codegen: "PyCodegen"): + return self.source.reconstruct(codegen) + + @staticmethod + def can_reorder_logs(fn, args, kwargs) -> True: + """ + Run some additional checks for what sort of function calls can we + actually reorder. + """ + + allowed_input_types = ( + variables.TensorVariable, + variables.ConstantVariable, + StringFormatVariable, + ) + + flat_args = pytree.tree_leaves([args, kwargs]) + for arg in flat_args: + if not isinstance(arg, allowed_input_types): + return False + + return True + + +class LoggingLoggerVariable(VariableTracker): + """ + Represents a call to any of logging.Logger methods + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + method = getattr(self.value, name, None) + function = getattr(method, "__func__", None) + if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): + return variables.ConstantVariable.create(None) + unimplemented( + gb_type="logging.Logger method not supported for non-export cases", + context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}", + explanation="logging.Logger methods are not supported for non-export cases.", + hints=[ + "Add the logging method to `torch._dynamo.config.ignore_logger_methods.", + ], + ) + + +class ConstantLikeVariable(VariableTracker): + """self.value is a compile-time constant, but not a literal""" + + try: + from numpy import ( + dtype as np_dtype, + floating as np_floating, + generic as np_generic, + ) + except ImportError: + np_floating = type("invalid_type", (), {}) + np_dtype = type("invalid_type", (), {}) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @property + def _error_prefix(self): + """Dynamically compute the prefix from the value's type""" + t = type(self.value) + + # For builtins (int, str, etc.), just return the name + if t.__module__ == "builtins": + return t.__qualname__ + + return f"{t.__module__}.{t.__qualname__}" + + def as_python_constant(self): + return self.value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + # we only support constant propagation for methods + cargs = [x.as_python_constant() for x in args] + ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + unimplemented( + gb_type="constant-like method call with non-constant args", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})", + explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.", + hints=[ + "Ensure that the args to the method call are constant (int, str, etc.).", + ], + ) + + result = getattr(self.value, name)(*cargs, **ckwargs) + + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + if isinstance(result, re.Match): + return ConstantLikeVariable(result) + + unimplemented( + gb_type="constant-like method call with unsupported return type", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}", + explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + result = getattr(self.value, name) + if isinstance(result, self.np_floating): + result = float(result) + if isinstance(result, self.np_dtype): + return NumpyDTypeVariable(result) + if isinstance(result, type) and issubclass(result, self.np_generic): + # things like x.dtype.type + return NumpyVariable(result) + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + return GetAttrVariable(self, name) + + +class TorchVersionVariable(ConstantLikeVariable): + _error_prefix = "torch.__version__" + + def __init__(self, **kwargs) -> None: + kwargs.setdefault("value", torch.__version__) + assert kwargs["value"] is torch.__version__ + super().__init__(**kwargs) + + +class NumpyDTypeVariable(ConstantLikeVariable): + def as_proxy(self): + """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: + + np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. + This also handles unsupported things nicely (i.e. structured arrays and object arrays). + """ + return self.value.type.__name__ + + +np_constant_collections_map = { + tnp.finfo: ConstantLikeVariable, + tnp.iinfo: ConstantLikeVariable, + tnp.dtype: NumpyDTypeVariable, +} + + +class RandomClassVariable(VariableTracker): + """random.Random""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if len(args) > 1 or kwargs: + unimplemented( + gb_type="random.Random() with improper arguments", + context=f"args: {args}, kwargs: {kwargs}", + explanation="random.Random() with > 1 arg or with kwargs is not supported.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] + return RandomVariable( + seed=seed, mutation_type=variables.base.ValueMutationNew() + ) + + +class RandomVariable(VariableTracker): + """random.Random() + + Implemented by wrapping a VariableTracker around a random.Random object. + The supported methods for the random.Random object cannot be overridden. + Assumes that random objects behave the same given a set seed or state. + """ + + _nonvar_fields = { + "random", + *VariableTracker._nonvar_fields, + } + + _supported_fn_names = { + "random", + "randint", + "randrange", + "uniform", + } + + def __init__( + self, + rand: Optional[random.Random] = None, + seed: Optional[VariableTracker] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if rand is not None: + assert self.is_supported_random_obj(rand) + self.random = random.Random() + self.random.setstate(rand.getstate()) + else: + seed = seed.as_python_constant() if seed is not None else None + self.random = random.Random(seed) + + def python_type(self): + return random.Random + + def as_python_constant(self): + return self.random + + @staticmethod + def is_supported_random_obj(val): + if type(val) is not random.Random: + return False + for name in itertools.chain( + RandomVariable._supported_fn_names, ("seed", "getstate", "setstate") + ): + if not hasattr(val, name): + return False + meth = getattr(val, name) + if inspect.isbuiltin(meth): + # e.g. random.Random.random + if meth != getattr(random.Random, name).__get__(val): + return False + else: + if getattr(meth, "__func__", None) is not getattr(random.Random, name): + return False + return True + + @staticmethod + def check_state(state): + assert type(state) is tuple + assert type(state[0]) is int + assert type(state[1]) is tuple + assert all(type(x) is int for x in state[1]) + assert state[2] is None or type(state[2]) is float + + @staticmethod + def wrap_state(state): + RandomVariable.check_state(state) + return variables.TupleVariable( + [ + variables.ConstantVariable.create(state[0]), + variables.TupleVariable( + [variables.ConstantVariable.create(x) for x in state[1]] + ), + variables.ConstantVariable.create(state[2]), + ] + ) + + @staticmethod + def unwrap_state(state): + state_obj = state.as_python_constant() + RandomVariable.check_state(state_obj) + return state_obj + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "seed": + tx.output.side_effects.mutation(self) + self.random.seed( + *[x.as_python_constant() for x in args], + **{key: val.as_python_constant() for key, val in kwargs.items()}, + ) + return variables.ConstantVariable.create(None) + elif name == "getstate": + return self.wrap_state(self.random.getstate()) + elif name == "setstate": + tx.output.side_effects.mutation(self) + self.random.setstate(self.unwrap_state(args[0])) + return variables.ConstantVariable.create(None) + elif name in self._supported_fn_names: + tx.output.side_effects.mutation(self) + state = self.random.getstate() + + def call_random_meth(*args, **kwargs): + r = random.Random() + r.setstate(state) + return getattr(r, name)(*args, **kwargs) + + # self.random state not actually updated by call_random_meth, so update here + # by calling the method + getattr(self.random, name)( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + + return call_random_fn(tx, call_random_meth, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(random), + codegen.create_load_attr("Random"), + ] + ) + ) + codegen.call_function(0, False) + # NOTE using add_push_null may result in NULL being duplicated + # so defer the push_null to call_function + codegen.dup_top() + codegen.load_attr("setstate") + codegen(self.wrap_state(self.random.getstate())) + codegen.call_function(1, True) + codegen.pop_top() + + +class WeakRefVariable(VariableTracker): + @staticmethod + def build(tx, weakref_value, **options): + source = options.get("source") + callback = weakref_value.__callback__ + callback_source = source and AttrSource(source, "__callback__") + callback_vt = VariableTracker.build(tx, callback, callback_source) + referent = weakref_value() + source = source and WeakRefCallSource(source) + referent_vt = VariableTracker.build(tx, referent, source) + options["source"] = source + return WeakRefVariable(referent_vt, callback_vt, **options) + + def __init__(self, referent_vt, callback_vt, **options): + super().__init__(**options) + self.referent_vt = referent_vt + self.callback_vt = callback_vt + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.referent_vt + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) + codegen(self.referent_vt) + codegen(self.callback_vt) + codegen.extend_output(create_call_function(2, False)) + + def is_python_hashable(self): + return self.referent_vt.is_python_hashable() + + def get_python_hash(self): + # weakref relies on the referent's hash + return self.referent_vt.get_python_hash() + + def is_python_equal(self, other): + return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3b2b792215ccdec807f20a31d06e9fdd937e49 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py @@ -0,0 +1,1378 @@ +""" +This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing. + +It provides specialized handling for different types of nn.Module instances through several key classes: + +- NNModuleVariable: Handles instance-specific module tracing, specializing on module id() and placing + parameters directly on the torch.fx.GraphModule. This creates one graph per module instance. + +- UnspecializedNNModuleVariable: Provides class-level module tracing, treating nn.Modules like other + user-defined objects and passing parameters as inputs to the FX graph. This creates one graph per + module class. + +- UnspecializedBuiltinNNModuleVariable: Specifically handles built-in PyTorch modules (e.g. nn.Linear) + with appropriate optimizations. + +- FSDPManagedNNModuleVariable: Special handling for FSDP-wrapped modules with modified guarding behavior + and parameter handling. + +The module integrates with Dynamo's broader tracing functionality to handle module method calls, +parameter access, hooks, and other nn.Module behaviors while maintaining proper scoping and guarding +of module state. +""" + +import functools +import inspect +import itertools +import re +import types +from collections.abc import Iterable, Sequence +from contextlib import contextmanager, nullcontext +from typing import Any, Optional, TYPE_CHECKING + +import torch.nn +from torch._guards import Source + +from .. import graph_break_hints, trace_rules, variables +from ..exc import raise_observed_exception, unimplemented, UnspecializeRestartAnalysis +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import GenerationTracker +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + FSDPNNModuleSource, + GetItemSource, + NNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + get_custom_getattr, + get_fake_value, + is_lazy_module, + is_namedtuple, + is_safe_constant, + istensor, + istype, + nnmodule_has_hooks, + object_has_getattribute, + proxy_args_kwargs, + raise_args_mismatch, + set_example_value, + unpatched_nn_module_call, + unpatched_nn_module_call_impl, +) +from .base import raise_type_error_exc, typestr, ValueMutationNew, VariableTracker +from .functions import invoke_and_store_as_constant +from .lazy import LazyVariableTracker +from .lists import SliceVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .constant import ConstantVariable + + +def initialize_lazy_module( + tx: "InstructionTranslator", + mod: torch.nn.Module, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> None: + """ + Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable. + + Used to cause lazy module to be initialized (and delete its init hook) before tracing. Especially + useful now that 'allowed' modules graph-break on hooks, calling this first ensures there is no hook + by the time we trace __call__ and thus no graph-break for lazy allowed modules. + """ + if hasattr(mod, "_initialize_hook"): + + def convert_to_fake(x: Any) -> Any: + if is_namedtuple(x): + return type(x)(*(convert_to_fake(elem) for elem in x)) + elif isinstance(x, dict): + return {k: convert_to_fake(v) for k, v in x.items()} # type: ignore[misc] + elif isinstance(x, (list, tuple, set)): + return type(x)(convert_to_fake(elem) for elem in x) + elif isinstance(x, torch.fx.Proxy): + return get_fake_value(x.node, tx) + else: + return x + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + fake_args = [convert_to_fake(arg) for arg in proxy_args] + fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} + try: + mod._infer_parameters(mod, fake_args, fake_kwargs) # type: ignore[operator] + except AttributeError as e: + # Re-raise with the original error message from the AttributeError + raise_observed_exception( + AttributeError, + tx, + args=[ + str(e) + if str(e) + else "AttributeError during lazy module initialization" + ], + ) + + +@contextmanager +def record_nn_module_stack( + module_key: str, source: Source, tx: "InstructionTranslator", mod: torch.nn.Module +) -> Any: + fully_qualified_name = source.name + # Remove redundant namings + fully_qualified_name = re.sub( + r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", + r".\2", + fully_qualified_name, + ) + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key + try: + tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 + yield + finally: + del tx.nn_module_stack[module_key] + + +def guard_to_detect_forward_monkeypatching( + source: Optional[Source], mod: torch.nn.Module +) -> None: + # Users sometimes patch the forward method of a nn module instance to + # perform optimizations like quantization. Though this is not a good + # software practice, but python allows this and Dynamo needs to detect + # this patching. + # + # One way to do this is to add an ID_MATCH guard on every function + # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But + # this increased guard overhead by around 20%. + # + # To keep the guard overhead down, we just guard on the `forward` being + # not present in the mod __dict__. The common case of patching forward + # method adds `forward` in the instance __dict__, whereas the unpatched + # `forward` sits in the type(mod).__dict__ + if source: + if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]): + # Monkeypatched forward method, add an ID_MATCH guard on forward function + fwd = mod.__dict__["forward"] + forward_source = AttrSource(source, "forward") + if type(fwd) is types.MethodType: + forward_source = AttrSource(forward_source, "__func__") + install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH)) + else: + # Common case - check that the forward key is absent in mod __dict__ + install_guard( + source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward" + ) + ) + ) + + +class NNModuleVariable(VariableTracker): + _nonvar_fields = { + "module_type", + "module_key", + "value", + "nn_module_stack_source", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.module_type = module_type + self.module_key = module_key + self.value = value + # pyrefly: ignore[bad-override] + # NOTE: Don't remove this; better than adding suppressions + # everywhere else with asserts + self.source: Source = self.source + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res + + def set_nn_module_stack_source(self, source: Source) -> None: + self.nn_module_stack_source = source + + def python_type(self) -> type: + return self.module_type + + def _wrap_submodule( + self, + tx: "InstructionTranslator", + source: Source, + submod: torch.nn.Module, + *key_extra: Any, + **options: Any, + ) -> None: + return + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + # implement list/iter/tuple/etc calls + base = tx.output.get_submodule(self.module_key) + result: list[VariableTracker] = [] + if isinstance(base, torch.nn.ModuleDict): + for name, submod in base.items(): + name_var = variables.ConstantVariable.create(name) + tx.output.register_attr_or_module( + submod, + self.module_key, + name, + source=NNModuleSource(GetItemSource(self.source, name)), # type: ignore[arg-type] + ) + result.append(name_var) + return result + + assert isinstance( + base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential) + ), typestr(base) + for idx, submod in enumerate(base): + result.append( + tx.output.register_attr_or_module( + submod, + self.module_key, + idx, + source=NNModuleSource(GetItemSource(self.source, idx)), + ) + ) + return result + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + mod = tx.output.get_submodule(self.module_key) + result = hasattr(mod, name) + install_guard( + NNModuleSource(AttrSource(self.source, name)).make_guard( + GuardBuilder.HASATTR + ) + ) + return variables.ConstantVariable.create(result) + + def is_training(self, tx: "InstructionTranslator") -> bool: + mod = tx.output.get_submodule(self.module_key) + return getattr(mod, "training", False) + + def convert_to_unspecialized(self, tx: "InstructionTranslator") -> None: + """Restart analysis treating this module as an UnspecializedNNModuleVariable""" + mod = tx.output.get_submodule(self.module_key) + GenerationTracker.tag(mod) + + # Mark the class dynamic unless its module initialization + if tx.f_code.co_name != "__init__": + GenerationTracker.mark_class_dynamic(type(mod)) + raise UnspecializeRestartAnalysis + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool: + base = tx.output.get_submodule(self.module_key) + + if object_has_getattribute(base): + unimplemented( + gb_type="Custom __getattribute__ in nn.Module dict key check", + context=f"has_key_in_generic_dict {self} {key}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + base_dict = object.__getattribute__(base, "__dict__") + return key in base_dict + + def _custom_getattr_fallback( + self, + base: torch.nn.Module, + tx: "InstructionTranslator", + name: str, + obj_source: Source, + ) -> Optional[VariableTracker]: + """Check for a __getattr__ and handle it specially if it is implemented""" + if object_has_getattribute(base): + unimplemented( + gb_type="Custom __getattribute__ in nn.Module attribute access", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) + if getattr_fn is None: + return None + + if not isinstance(getattr_fn, types.FunctionType): + unimplemented( + gb_type="torch.nn.Module with a non-function custom __getattr__", + context=f"var_getattr {self} {name}", + explanation=( + "Dynamo detected a nn.Module object with a custom " + "`__getattr__` method, but this method is not a standard " + "Python function (e.g., it might be implemented in C/C++). " + "Dynamo cannot currently trace into such non-standard " + "`__getattr__` methods." + ), + hints=[ + "Avoid using objects with non-standard __getattr__ methods " + "within the compiled region. If possible, implement " + "__getattr__ as a standard Python function.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + options = {"source": AttrSource(obj_source, "__getattr__")} + # pyrefly: ignore[bad-argument-type] + return variables.UserMethodVariable(getattr_fn, self, **options).call_function( + tx, [variables.ConstantVariable.create(name)], {} + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + + base = tx.output.get_submodule(self.module_key) + base_dict = object.__getattribute__(base, "__dict__") + object_member = True + all_class_attribute_names = set() + for x in inspect.getmro(base.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + if not self.source: + unimplemented( + gb_type="getattr with no source", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not know how to access an attribute " + "on an `nn.Module` instance that lacks a source. This is " + "usually an internal error in Dynamo.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + if name == "__dict__": + return variables.GetAttrVariable(self, name, source=source) + + subobj = None + if name in base_dict: + subobj = base_dict[name] + elif ( + "_modules" in base_dict + and name in base_dict["_modules"] + and name not in all_class_attribute_names + ): + subobj = base_dict["_modules"][name] + elif "_parameters" in base_dict and name in base_dict["_parameters"]: + subobj = base_dict["_parameters"][name] + elif "_buffers" in base_dict and name in base_dict["_buffers"]: + subobj = base_dict["_buffers"][name] + else: + try: + subobj = inspect.getattr_static(base, name) + object_member = False + except AttributeError: + # see if we can fallback to __getattr__, which is not checked by getattr_static + result = self._custom_getattr_fallback( + base=base, tx=tx, name=name, obj_source=self.source + ) + if result is not None: + return result + # if we can't find a __getattr__, we can't parse this, raise attribute error + raise_observed_exception( + AttributeError, + tx, + args=[f"'{type(base).__name__}' object has no attribute '{name}'"], + ) + + if name == "forward": + guard_to_detect_forward_monkeypatching(self.source, base) + + if name == "__class__" and not object_member: + return variables.UserDefinedClassVariable(base.__class__, source=source) + + if object_member: + out = VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] + + if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + else: + if istype(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = AttrSource(AttrSource(self.source, "__class__"), name) + # Get the getter function + source = AttrSource(source, "fget") + return variables.UserFunctionVariable( + subobj.fget, # pyrefly: ignore[bad-argument-type] + source=source, + ).call_function(tx, [(self)], {}) + elif istype(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, + variables.UserDefinedObjectVariable(type(base)), + source=source, + ) + elif istype(subobj, staticmethod): + return variables.UserFunctionVariable( + # pyrefly: ignore[bad-argument-type] + subobj.__get__(base), + source=source, + ) + elif istype(subobj, types.FunctionType): + return variables.UserMethodVariable(subobj, self, source=source) + elif is_safe_constant(subobj) or istensor(subobj): + # Support possibly common cases of class members + return VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] + else: + unimplemented( + gb_type="Unsupported nn.Module attribute type", + context=f"nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", + explanation=f"Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", + hints=[ + f"Refactor your code so that `{name}` (type `{typestr(subobj)}`) is not an attribute of `{typestr(base)}`", + "Currently supported attribute types are methods, classmethods, staticmethods, " + "properties, constants, and tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.GetAttrVariable(self, name, source=source) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + mod = tx.output.get_submodule(self.module_key) + + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, mod + ): + is_lazy = is_lazy_module(mod) + if ( + isinstance(mod, torch.nn.Sequential) + and mod.__class__.forward is torch.nn.Sequential.forward + ): + if nnmodule_has_hooks(mod): + # We do not want to unroll sequential if it has hooks, since evaporating it + # will cause hooks to not fire! + # This terminates and restart the tracing process + self.convert_to_unspecialized(tx) + + # Unroll sequential + assert not is_lazy, ( + "Expected lazy sequential isn't a valid combination?" + ) + if kwargs: + raise_args_mismatch( + tx, + "torch.nn.Module.Sequential", + "0 kwargs", + f"{len(kwargs)} kwargs", + ) + (arg,) = args + # TODO: Use named_children when it supports remove_duplicate=False. + for child_name, submod in mod._modules.items(): + tx.call_function( + tx.output.register_attr_or_module( + submod, + self.module_key, + child_name, + source=NNModuleSource(AttrSource(self.source, child_name)), # type: ignore[arg-type] + ), + [arg], + {}, + ) + arg = tx.pop() + return arg + + if is_lazy: + # The module type will change after it is called + if mod.cls_to_become is not None: + self.module_type = mod.cls_to_become # type: ignore[assignment] + + # The pre-hook runs to initialize the module shapes, then deletes itself. After this, + # the module is more or less not lazy and can be treated as a normal module regardless of + # is_allowed or other variations. + initialize_lazy_module(tx, mod, args, kwargs) + + # If we are tracing the higher order op, we want Dynamo to step + # inside the module call so that Dynamo can see the underlying + # parameters and buffers and raise them as inputs to the graph. + # + # NB: torch.nn.utils.parametrize changes the class type of a + # parametrized module such that its __module__ points to + # "torch.nn.utils.parametrize". + if ( + tx.output.is_root_tracer() + and mod.__module__.startswith(("torch.nn.", "torch.ao.")) + and mod.__module__ != "torch.nn.utils.parametrize" + # this basically means we are using the new strict export tracer which wraps the + # user callable, so we shouldn't directly proxy in the fx graph + and not isinstance( + mod, torch.ao.quantization.pt2e.export_utils._WrapperModule + ) + ): + if nnmodule_has_hooks( + mod, check_forward_hooks=True, check_backward_hooks=True + ): + # End of fn, this bubbles up and restarts tracing. + self.convert_to_unspecialized(tx) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_module", + self.module_key, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + if isinstance(mod, torch.fx.GraphModule): + # TODO: do we want to support __call__ for GM's? + # If so at least some changes are needed, we don't allow inlining + # the call_wrapped currently, and maybe other issues too + fn = mod.forward + fn_source = AttrSource(self.source, "forward") + else: + fn = mod._call_impl + fn_source = AttrSource(self.source, "_call_impl") + if istype(fn, types.MethodType): + fn = fn.__func__ + fn_source = AttrSource(fn_source, "__func__") + args = [self] + list(args) + else: + assert istype(fn, types.FunctionType) + return tx.inline_user_function_return( + # pyrefly: ignore[bad-argument-type] + variables.UserFunctionVariable(fn, source=fn_source), + args, + kwargs, + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + constant: bool = False, + ) -> VariableTracker: + from . import ConstantVariable, ListIteratorVariable, TupleVariable + + key = self.module_key + module = tx.output.get_submodule(key) + + def generic_call_method_helper(name: str) -> VariableTracker: + # Helper function to put a `call_method` node in FX graph, + # with nn.Module as the first arg. + mod_proxy = tx.output.create_proxy( + "get_attr", + self.module_key, + (), + {}, + ) + set_example_value(mod_proxy.node, module) + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + name, + args=(mod_proxy, *proxy_args), + kwargs=proxy_kwargs, + ), + ) + + if name in ["_call_impl", "_wrapped_call_impl"]: + # Example: `self.layer.__call__(x)` + # This is used for explicit calling `__call__` in a forward function. + # Dynamo inlines `__call__`, includes hooks. + return self.call_function(tx, args, kwargs) + elif name == "forward": + # Example: `self.layer.forward(x)` + # This is used for explicit calling `forward` in a forward function. + # Dynamo puts `call_method` node in FX, doesn't trigger hooks. + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, module + ): + return generic_call_method_helper(name) + + if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( + inspect.getfile(module.__class__._check_input_dim) # type: ignore[union-attr] + ): + return ConstantVariable.create(True) + + if name == "_get_item_by_idx": + if not args[1].is_python_constant(): + raise_type_error_exc( + tx, + f"``nn.Module`` {module}'s call method {name} requires a constant index argument", + ) + if not isinstance(args[0], TupleVariable): + raise_type_error_exc( + tx, + f"``nn.Module`` {module}'s call method {name} requires a tuple as first argument", + ) + mod_var = args[0].items[args[1].value] # type: ignore[attr-defined] + if isinstance(mod_var, UnspecializedNNModuleVariable): + return mod_var + key = mod_var.module_key # type: ignore[attr-defined] + submod = tx.output.get_submodule(key) + return tx.output.register_attr_or_module( + submod, + key, + key, + source=NNModuleSource(GetItemSource(self.source, key)), + ) + + if constant: + fn = getattr(module, name) + name = f"{module.__class__.__name__}_{name}_result" + return invoke_and_store_as_constant(tx, fn, name, args, kwargs) + + def assert_all_args_kwargs_const() -> None: + if not all( + x.is_python_constant() for x in itertools.chain(args, kwargs.values()) + ): + unimplemented( + gb_type="non-const argument in nn.Module method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support calling " + f"method `{name}` of ``nn.Module`` {module} with non-constant arguments.", + hints=[], + ) + + def get_kwargs(*names: str) -> dict[str, Any]: + assert_all_args_kwargs_const() + fn = getattr(module, name) + bound_args = inspect.signature(fn).bind( + *([x.as_python_constant() for x in args]), + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + bound_args.apply_defaults() + bound_args = bound_args.arguments + return {k: bound_args[k] for k in names} + + def wrap_values( + items: Iterable[tuple[Any, Any]], + ) -> "variables.ListIteratorVariable": + result = [] + for name, submod in items: + result.append( + tx.output.register_attr_or_module( + submod, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ) + ) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) + + def named_embed(name: str, obj: Any) -> "variables.TupleVariable": + return TupleVariable( + [ + ConstantVariable.create(name), + tx.output.register_attr_or_module( + obj, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ), + ] + ) + + def gen_source(source: Source, name: str) -> Source: + name_split = name.split(".") + if name_split[0] == "": + return source + while len(name_split) > 0: + x = name_split.pop(0) + source = AttrSource(source, x) + return source + + if name == "named_children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + named_children: list[VariableTracker] = [] + for name, submod in module.named_children(): + named_children.append(named_embed(name, submod)) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) + elif name == "named_parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + named_parameters: list[VariableTracker] = [] + for name, param in module.named_parameters( + **get_kwargs("prefix", "recurse") + ): + named_parameters.append(named_embed(name, param)) + return ListIteratorVariable( + named_parameters, mutation_type=ValueMutationNew() + ) + elif name == "named_buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + named_buffers: list[VariableTracker] = [] + for name, buffer in module.named_buffers( + **get_kwargs("prefix", "recurse", "remove_duplicate") + ): + named_buffers.append(named_embed(name, buffer)) + return ListIteratorVariable(named_buffers, mutation_type=ValueMutationNew()) + elif name == "named_modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + named_modules_list: list[VariableTracker] = [] + for name, submod in module.named_modules( + **get_kwargs("memo", "prefix", "remove_duplicate") + ): + named_modules_list.append(named_embed(name, submod)) + return ListIteratorVariable( + named_modules_list, mutation_type=ValueMutationNew() + ) + elif name == "children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return wrap_values(module.named_children()) + elif name == "modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + return wrap_values(module.named_modules()) + elif name == "parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + return wrap_values(module.named_parameters(**get_kwargs("recurse"))) + elif name == "buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + return wrap_values(module.named_buffers(**get_kwargs("recurse"))) + elif name == "keys": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + result = [] + # pyrefly: ignore[not-iterable] + for tmp in module: + result.append(ConstantVariable.create(tmp)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "values": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return wrap_values(module.items()) # type: ignore[operator] + elif name == "items": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items_result: list[VariableTracker] = [] + for name, submod in module.items(): # type: ignore[operator] + items_result.append(named_embed(name, submod)) + return ListIteratorVariable(items_result, mutation_type=ValueMutationNew()) + elif name == "__len__": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return ConstantVariable.create(len(module)) # type: ignore[arg-type] + elif name == "__iter__": + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + elif ( + name == "__contains__" + and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict)) + and args + and args[0].is_python_constant() + ): + return ConstantVariable.create( + args[0].as_python_constant() in module._modules + ) + elif name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + builtin_supported = ( + torch.nn.ModuleDict.__getitem__, + torch.nn.ModuleList.__getitem__, + torch.nn.ParameterDict.__getitem__, + torch.nn.ParameterList.__getitem__, + torch.nn.Sequential.__getitem__, + ) + # pyrefly: ignore[missing-attribute] + if type(module).__getitem__ not in builtin_supported: + if not ( + args[0].is_python_constant() + and isinstance(args[0].as_python_constant(), (str, int)) + ): + unimplemented( + gb_type="Invalid or non-const argument in nn.Module __getitem__", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support calling " + f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.", + hints=[ + "Use constant arguments of type str or int for __getitem__" + ], + ) + fn = getattr(module, name).__func__ + + assert isinstance(fn, types.FunctionType) + + src = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=src), + [self] + list(args), + kwargs, + ) + + if isinstance(args[0], SliceVariable): + # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is + # enabled for export. + if tx.output.export: + # Build a TupleVariable of NNModules + result = [] + + # Turn the slice into the list of integers + keys = list(range(len(module)))[args[0].as_python_constant()] # type: ignore[arg-type] + for idx, submod in enumerate(module[args[0].as_python_constant()]): # type: ignore[arg-type] + key = keys[idx] + src = NNModuleSource(GetItemSource(self.source, key)) + result.append( + tx.output.register_attr_or_module( + submod, + key, + source=src, + ) + ) + + new_module = module[args[0].as_python_constant()] # type: ignore[index] + new_module_variable = tx.output.register_attr_or_module( + new_module, + f"{self}.__getitem__(slice)", + source=NNModuleSource( + GetItemSource(self.source, args[0].as_python_constant()) + ), + ) + return new_module_variable + else: + # slice on nn module results in a creation of new module instance, so we need to make it sourceless. + # Convert to unspecialized so that UnspecializedNNModule variable can take care of it. + self.convert_to_unspecialized(tx) + + from .tensor import SymNodeVariable + + key_value = 0 + if isinstance(args[0], SymNodeVariable): + key_value = args[0].evaluate_expr(tx.output) + elif args[0].is_python_constant(): + key_value = args[0].as_python_constant() + else: + unimplemented( + gb_type="Unsupported key type for nn.Module.__getitem__", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support getitem on " + "`nn.Module` with non-constant key.", + hints=[], + ) + + submod = module[key_value] # type: ignore[index] + return tx.output.register_attr_or_module( + submod, + self.module_key, + key_value, + source=NNModuleSource(GetItemSource(self.source, key_value)), + ) + elif ( + name == "_get_abs_string_index" + or ( + isinstance(module, torch.nn.modules.conv._ConvNd) + and name == "_conv_forward" + ) + or ( + isinstance(module, torch.nn.modules.conv._ConvTransposeNd) + and name == "_output_padding" + ) + ): + # Inline the function + fn = getattr(module, name).__func__ + fn_source = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + [self] + list(args), + kwargs, + ) + # A loose heuristic, but seems to be generally good before we drop into the + # manual handling of inputs + elif ( + name in module.__class__.__dict__ + and callable(module.__class__.__dict__[name]) + and all(x.is_tensor() for x in itertools.chain(args, kwargs.values())) + ): + return generic_call_method_helper(name) + else: + return super().call_method(tx, name, list(args), kwargs) + + +class UnspecializedNNModuleVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "value_type", + "is_state_mutated", + "nn_module_stack_source", + *UserDefinedObjectVariable._nonvar_fields, + } + + """ + The above class will specialize on the id() of a module and place + parameters on the torch.fx.GraphModule. Giving one graph per + module instance. This version treats nn.Modules() like other user + defined objects and will pass parameters into the FX graph as inputs. + Giving one graph per module class. + """ + + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: + if type(value) is torch.jit._script.RecursiveScriptModule: + unimplemented( + gb_type="UnspecializedNNModuleVariable wrapped around ScriptModules unsupported", + context=str(value), + explanation="ScriptModules aren't supported in UnspecializedNNModuleVariable" + " because their .forward function isn't a static member of their type.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + if "value_type" in kwargs: + lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) + if type(value) is lazy_value_to_become: + # We may have cloned a variabletracker for a LazyModule earlier (e.g. tracking side-effects) + # and then later we called and mutated the LazyModule into a MaterializedModule. + # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only + # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation. + kwargs["value_type"] = type(value) + + super().__init__(value=value, **kwargs) + self.is_state_mutated = False + # nn_module_stack_source is used to ensure BC for nn_module_stack. + # Downstream users prefer mod.linear instead of mod._modules['linear'] + # as the module stack. When Dynamo inlines the __getattr__ method, we + # cannot use self.source for nn_module_stack because it will be similar + # to mod._modules['linear']. In these cases, we set the + # nn_module_stack_source appropriately to resemble mod.linear. + self.nn_module_stack_source = self.source + + def _wrap_source(self, attr_source: Source) -> Source: + # the vt is already wrapped with UnspecializedNNModuleSource + return attr_source + + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res + + def set_nn_module_stack_source(self, source: Source) -> None: + self.nn_module_stack_source = source + + @staticmethod + @functools.cache + def _nn_module_method_ids() -> set[int]: + # Allow __setattr__ to fall through to base class handler + supported = { + torch.nn.Module.__setattr__, + torch.nn.Module.__init__, + torch.nn.Module.__delattr__, + } + return { + id(x.__code__) + for x in torch.nn.Module.__dict__.values() + if hasattr(x, "__code__") and x not in supported + } + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + try: + fn = inspect.getattr_static(self.value_type, "__iter__") + except AttributeError as e: + raise NotImplementedError from e + + if fn in ( + torch.nn.ModuleList.__iter__, + torch.nn.ParameterList.__iter__, + torch.nn.Sequential.__iter__, + ): + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + VariableTracker.build(tx, fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) + + return super().unpack_var_sequence(tx) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + mod = self.value + # see comment on lazy module handling in NNModuleVariable.call_function for context + if is_lazy_module(mod): # type: ignore[arg-type] + if mod.cls_to_become is not None: # type: ignore[attr-defined] + self.value_type = mod.cls_to_become # type: ignore[attr-defined,assignment] + initialize_lazy_module(tx, mod, args, kwargs) # type: ignore[arg-type] + + if not isinstance(mod, torch.fx.GraphModule): + name = "__call__" + fn = getattr(self.value_type, name) + else: + name = "_call_impl" + fn = getattr(self.value_type, name) + + # Check if we can short circuit nn.Module._call_impl to the forward + # method. NB - This is done to reduce the compile time of Dynamo. + if ( + istype(mod.__call__, types.MethodType) # type: ignore[operator] + and istype(mod._call_impl, types.MethodType) # type: ignore[attr-defined] + and mod.__call__.__func__ is unpatched_nn_module_call # type: ignore[operator] + and mod._call_impl.__func__ is unpatched_nn_module_call_impl # type: ignore[attr-defined] + and "forward" not in mod.__dict__ + ): + forward_method = inspect.getattr_static(mod, "forward") + if isinstance(forward_method, types.FunctionType): + globals_vt = tx.nn_modules_globals_vt + if not ( + self.var_getattr(tx, "_backward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_backward_pre_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_pre_hooks").realize().len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] + ): + name = "forward" + fn = self.value_type.forward + + if self.source: + source = self.get_source_by_walking_mro(name) + else: + source = None + + guard_to_detect_forward_monkeypatching(self.source, mod) # type: ignore[arg-type] + + ctx = ( + record_nn_module_stack( + str(id(mod)), + self.get_nn_module_stack_source(), + tx, + mod, # type: ignore[arg-type] + ) + if self.source + else nullcontext() + ) + with ctx: + if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, [self] + list(args), kwargs) + else: + # Ideally we would have just used VariableTracker.build(tx, fn, + # source=source) but that introduces guard on the + # `forward.__code__` object. Given that we already guard on the + # forward not present in generic dict, we dont need this guard. + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["_call_impl", "_wrapped_call_impl"]: + fn = getattr(self.value_type, name) + if self.source: + source = self.get_source_by_walking_mro(name) + else: + source = None + + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, [self] + list(args), kwargs) + + if name not in getattr(self.value, "__dict__", {}): + try: + method = inspect.getattr_static(type(self.value), name) + except AttributeError: + method = None + + if isinstance(method, staticmethod): + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") + fn_vt = VariableTracker.build(tx, method.__func__, source=source) + return fn_vt.call_function(tx, args, kwargs) + + if ( + hasattr(method, "__code__") + and id(method.__code__) in self._nn_module_method_ids() + ): + unimplemented( + gb_type="UnspecializedNNModuleVariable missing method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not support tracing method {name} of nn.Module {self.value}", + hints=[ + "Dynamo does not really define unspecialized nn.Module very well.", + *graph_break_hints.DIFFICULT, + ], + ) + + # "_parameters" in self.value.__dict__ checks that module is initialized + if name == "__setattr__" and "_parameters" in self.value.__dict__: + # Record if mutations happens on parameters/buffers/modules. The + # mutations on these are not tracked by base class + # UserDefinedObject vt. This will be used later to graph break + # on seeing a parameters() and family calls. + # TODO(anijain2305) - This might not be needed if we let Dynamo + # inline both getattr and setattr. In that case, it should see + # the lowest level dicts - _parameters and family and + # automatically track mutations on those. Investigate if that + # can be done. + attr_name = args[0].as_python_constant() + value = args[1] + + # This is reverse engineered by looking at nn module __setattr__ + # logic. + if ( + value.is_tensor() and value.python_type() is torch.nn.Parameter + ) or attr_name in self.value.__dict__["_parameters"]: + # Handle parameters + self.is_state_mutated = True + elif attr_name in self.value.__dict__["_buffers"]: + # Handle buffers + self.is_state_mutated = True + elif ( + isinstance( + value, + ( + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + ), + ) + or attr_name in self.value.__dict__["_modules"] + ): + # Handle submodules + self.is_state_mutated = True + + if ( + method is torch.nn.Module.__setattr__ + and isinstance(args[1], variables.DeletedVariable) + ) or method is torch.nn.Module.__delattr__: + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + fn_vt = VariableTracker.build(tx, torch.nn.Module.__delattr__) + return fn_vt.call_function(tx, [self, args[0]], kwargs) + + return super().call_method(tx, name, list(args), kwargs) + + def getattr_helper( + self, tx: "InstructionTranslator", field: str, name_vt: VariableTracker + ) -> Optional[VariableTracker]: + dict_vt = self.var_getattr(tx, field) + if isinstance(dict_vt, variables.ConstDictVariable): + return dict_vt.maybe_getitem_const(name_vt) + return None + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # Allow skipping of empty hook dict guards on inbuilt nn modules + if name in ( + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_pre_hooks", + ): + # For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us to control the installation of empty + # hooks guard via skip_nnmodule_hook_guards + if not tx.output.side_effects.has_pending_mutation_of_attr(self, name): + hooks_dict = getattr(self.value, name) + if isinstance(hooks_dict, dict) and len(hooks_dict) == 0: + if self.source: + hooks_source = AttrSource(self.source, name) + install_guard( + hooks_source.make_guard( + GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT + ) + ) + return variables.ConstDictVariable({}) + + # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable. + # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for + # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand why + # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a + # NNModuleHooksDictVariable (a subclass of ConstDictVariable) to avoid any guard on the keys. + if ( + self.source + and name + in ( + "_forward_pre_hooks", + "_forward_hooks", + ) + and not tx.output.side_effects.has_pending_mutation_of_attr(self, name) + ): + hooks_dict = getattr(self.value, name) + hooks_dict_source = AttrSource(self.source, name) + install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + tx.output.guard_on_key_order.add(hooks_dict_source) + + def build_key_value( + i: int, k: Any, v: Any + ) -> tuple[VariableTracker, VariableTracker]: + # Make key sourceless to avoid any guard on it + key = variables.ConstantVariable.create(k) + + # Instead of using dict[key] to access the value, use a dict[dict.keys()[index]] to access the + # value. This removes the reliance on the actual key value. + source_key = ConstDictKeySource(hooks_dict_source, i) + source_value = DictGetItemSource(hooks_dict_source, source_key) + value = LazyVariableTracker.create(v, source_value) + return key, value + + result = dict( + build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items()) + ) + + return variables.NNModuleHooksDictVariable( + result, type(hooks_dict), source=hooks_dict_source + ) + return super().var_getattr(tx, name) + + def manually_trace_nn_module_getattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + """ + Dynamo tracing of nn.Module __getattr__ can be expensive if the model + has deep submodule hierarchy. Since the __getattr__ is stable, we can + directly look into the underlying datastructures. This saves a lot of + compilation time. + """ + name_vt = variables.ConstantVariable(name) + out = self.getattr_helper(tx, "_parameters", name_vt) + if out is None: + out = self.getattr_helper(tx, "_modules", name_vt) + if out is None: + out = self.getattr_helper(tx, "_buffers", name_vt) + if out is None: + raise_observed_exception( + AttributeError, + tx, + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], + ) + assert out is not None + return out + + +class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable): + """ + Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules. + """ + + def _wrap_source(self, attr_source: Source) -> Source: + # vt is already wrapped with the UnspecializedBuiltinNNModuleSource + return attr_source + + +class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable): + """ + Tracing behavior: trace into submodules and treat them as Unspecialized, do not + register parameters to the top-level, treat them as function inputs. + + Guards behavior: if 'skip_fsdp_guards', many guards that would be installed + by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis + that a user wrapping their model in FSDP(model) is already opting into a + requirement to not modify internal model state, which would already break FSDP without + compilation. + """ + + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: + source = kwargs.get("source") + assert source is not None, ( + "FSDPManagedNNModule depends on having an accurate source to control guarding." + ) + + super().__init__(value=value, **kwargs) + self.source = source + + def _wrap_source(self, attr_source: Any) -> Any: + if not isinstance( + attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource) + ): + if torch._dynamo.config.skip_fsdp_guards: + return FSDPNNModuleSource(attr_source) + else: + return UnspecializedNNModuleSource(attr_source) + return attr_source diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..53d3acc0d40118acecfa4d71bbcf10486e3f3dcf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py @@ -0,0 +1,420 @@ +""" +This module implements variable tracking for PyTorch optimizers during Dynamo tracing. + +The OptimizerVariable class provides specialized handling for optimizer instances by: +- Optimizing the tracing of expensive optimizer initialization +- Managing optimizer state and parameter group tracking +- Handling tensor sources and guards for optimizer state tensors +- Supporting CUDA graph execution through static tensor address management +- Providing special handling for parameter gradients and optimizer state tensors + +Key features include: +- Efficient initialization tracing via _init_group optimization +- Automatic marking of optimizer state tensors as static for CUDA graphs +- Proper source tracking for parameter groups, gradients, and state tensors +- Guard installation for optimizer state structure +- Support for both CPU and GPU tensor handling +- Cleanup of static tensor references via finalizers + +The module integrates with Dynamo's broader tracing system while providing +optimizer-specific optimizations and safety guarantees. +""" + +import logging +import weakref +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING + +import torch +from torch._dynamo.variables.tensor import TensorVariable +from torch._guards import Source +from torch._logging import getArtifactLogger +from torch.utils._pytree import tree_map_only + +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + GetItemSource, + GlobalWeakRefSource, + GradSource, +) +from ..utils import GLOBAL_KEY_PREFIX +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import ListVariable +from .misc import GetAttrVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ArgMappingException(Exception): + pass + + +class GuardInstallException(Exception): + pass + + +perf_hint_log = getArtifactLogger(__name__, "perf_hints") + + +def _is_static_for_cudagraphs(x: torch.Tensor) -> bool: + from torch._inductor.cudagraph_trees import get_manager + + if x.is_cuda: + manager = get_manager(x.device.index, False) + is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None + if manager: + assert manager.current_node is not None + return ( + is_static_address + or manager.current_node._is_cuda_graph_recorded_tensor(x) + ) + else: + return is_static_address + else: + # Don't print a warning for non-cuda tensors + return True + + +class OptimizerVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "grad_to_source", + "tensor_to_source", + "static_tensor_names", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value: torch.optim.Optimizer, + grad_to_source: Optional[dict[Any, GradSource]] = None, + static_tensor_names: Optional[set[str]] = None, + tensor_to_source: Optional[dict[torch.Tensor, Source]] = None, + **kwargs: Any, + ) -> None: + super().__init__(value, **kwargs) + # pyrefly: ignore [bad-override] + self.value: torch.optim.Optimizer = value + self.grad_to_source = grad_to_source or {} + self.tensor_to_source = tensor_to_source or {} + self.static_tensor_names = static_tensor_names or set() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + """This is an optimization to avoid tracing the very slow initialization of the optimizer""" + if name == "_init_group": + if not hasattr(self.value, "_init_group"): + # Fallback: if the optimizer does not have _init_group, trace normally + return super().call_method(tx, name, args, kwargs) + try: + self.graph_break_if_pending_mutation(tx) + self.move_step_if_cpu() + py_args, py_kwargs = self.get_python_args(*args, **kwargs) + ret_val = self.value._init_group(*py_args, **py_kwargs) + self.map_sources_and_install_guards(tx) + self.update_list_args(tx, args, kwargs, py_args, py_kwargs) + # stash a weak_ptr to optimizer to invalidate code + # if the optimizer object dies + mangled_name = f"__optimizer_{id(self.value)}" + tx.store_global_weakref_by_id(mangled_name, self.value) + self.create_finalizer(tx) + + # This is currently safe only because the only actual `ret_val`s returned + # by the `_init_group` of existing optimizers are properties that are invariant + # to the input tensors (e.g. dtype, layout). Changing these would trigger a + # recompilation and hence never result in the wrong specialization of `ret_val`. + return ConstantVariable.create(ret_val) + except (ArgMappingException, GuardInstallException) as _: + # trace normally if we can't map args or install guards correctly + pass + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # Note: this allows us to intercept the call in call_method + # in the typical case, we return a UserMethodVariable + # which will directly inline + if name in ("_init_group"): + assert self.source + return GetAttrVariable(self, name, source=AttrSource(self.source, name)) + + if name == "param_groups": + from ..decorators import mark_static_address + + for group in self.value.param_groups: + for p in group["params"]: + mark_static_address(p, guard=True) + + self._set_capturable(tx) + + return super().var_getattr(tx, name) + + def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None: + # If there are pending mutations on a parameter (due to using closure) + # then we need to graph break to allow the python version of the parameter + # to update, so that running _init_group will initialize the states with + # the correct values + for g in self.value.param_groups: + for p in g["params"]: + side_effects = tx.output.side_effects + variable = side_effects.id_to_variable.get(id(p), None) + if variable and side_effects.has_pending_mutation(variable): + from ..exc import unimplemented + + unimplemented( + gb_type="optimizer: pending mutation on parameter", + context=f"variable: {variable}, parameter: {p}", + explanation="Pending mutations on a parameter (e.g. due to using closure) require a graph break.", + hints=[], + ) + + def _set_capturable(self, tx: "InstructionTranslator") -> None: + from . import LazyVariableTracker + + # We only set capturable if params are on cuda + # and the state is not initialized + def safe_to_set_capturable(group: dict[str, Any]) -> bool: + all_uninitialized = True + all_gpu = True + + for p in group.get("params", []): + all_gpu &= p.is_cuda or p.is_xpu + all_uninitialized &= p not in self.value.state + + return "capturable" in group and all_uninitialized and all_gpu + + # track indices to not set so we don't need to + # in the variable tracker realize the whole state + # we handle guarding the state specially + for group in self.value.param_groups: + if safe_to_set_capturable(group): + group["capturable"] = True + + source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, source) + ) + for param_group_vt in param_groups_vt.items: + key = ConstDictVariable._HashableTracker( + ConstantVariable.create("capturable") + ) + param_group_vt.items[key] = ConstantVariable.create(True) + + def get_python_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], dict[str, Any]]: + """Get python values equivalent to the variable tracker args""" + + def map_arg(arg: Any) -> Any: + if isinstance(arg, VariableTracker) and arg.is_python_constant(): + return arg.as_python_constant() + elif isinstance(arg, ListVariable) and not arg.items: + return [] + elif ( + isinstance(arg, ConstDictVariable) + and isinstance(arg.source, GetItemSource) + and isinstance(arg.source.base, AttrSource) + and arg.source.base.member == "param_groups" + ): + return self.value.param_groups[arg.source.index] + + raise ArgMappingException + + new_args = [map_arg(arg) for arg in args] + new_kwargs = {k: map_arg(v) for k, v in kwargs.items()} + + return new_args, new_kwargs + + # If users load an old state dictionary, + # it's possible that step could be on the cpu + # if this is the case, move it to the GPU + # corresponding to the parameter + # in most cases this is a no-op because the state is empty + def move_step_if_cpu(self) -> None: + for p, state in self.value.state.items(): + if "step" in state and state["step"].is_cpu: + state["step"] = state["step"].to(p.device) + + def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None: + from ..decorators import mark_static_address + from .lazy import LazyVariableTracker + + self.grad_to_source = {} + self.tensor_to_source = {} + + def mark_static(x: Any) -> None: + mark_static_address(x, guard=True) + + tree_map_only(torch.Tensor, mark_static, self.value.state) + + # Recursively realize the variable trackers for optim.state and + # optim.param_groups, which recursively install the necessary guards. + params_groups_source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, params_groups_source) + ) + + state_source = self.source and AttrSource(self.source, "state") + state_vt = VariableTracker.build(tx, self.value.state, state_source) + + # We need to realize the top level state dict to populate + # the guard locals + state_vt.realize() + assert state_source is not None + tx.output.guard_on_key_order.add(state_source) + + # Populate self.grad_to_source and self.tensor_to_source so that we can + # manually update_list_args + for group, group_vt in zip(self.value.param_groups, param_groups_vt.items): + # we assume here that all params within a param group + # are initialized similarly + if len(group["params"]) > 0: + for param in group["params"]: + if param.grad is not None: + key_index = None + for i, k in enumerate(self.value.state.keys()): + if k is param: + key_index = i + break + if key_index: + LazyVariableTracker.realize_all( + VariableTracker.build( + tx, + self.value.state[param], + DictGetItemSource( + state_source, + ConstDictKeySource(state_source, key_index), + ), + ) + ) + break + + params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params")) + all_static = True + non_static_grads = [] + for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)): + param_source = p_vt.source + self.tensor_to_source[p] = param_source + grad_source = GradSource( + param_source, + "grad", + ) + + if p.grad is not None: + self.grad_to_source[p.grad] = grad_source + if not _is_static_for_cudagraphs(p.grad): + all_static = False + non_static_grads.append(grad_source) + else: + install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + + # Note: to avoid spam logs only warn if perf hint artifact is enabled + # (NB: artifacts are only enabled at the debug or warning level) + if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): + non_static_grad_names = [src.name for src in non_static_grads] + perf_hint_log.warning( + ( + "Grad tensors %s will be copied during cudagraphs execution." + "If using cudagraphs and the grad tensor addresses will be the same across runs," + " use torch._dynamo.decorators.mark_static_address to elide this copy.", + ), + non_static_grad_names, + ) + + # We have to again iterate over the state dict to collect the + # tensor_to_source dict. This is used for the finalizer. + for idx, value in enumerate(self.value.state.values()): + p_state_source = DictGetItemSource( + state_source, ConstDictKeySource(state_source, idx) + ) + tx.output.guard_on_key_order.add(p_state_source) + for inner_idx, v in enumerate(value.values()): + if ( + isinstance(v, torch.Tensor) + and v not in self.grad_to_source + and v not in self.tensor_to_source + ): + self.tensor_to_source[v] = DictGetItemSource( + p_state_source, ConstDictKeySource(p_state_source, inner_idx) + ) + + def wrap_tensor( + self, tx: "InstructionTranslator", tensor_value: torch.Tensor + ) -> TensorVariable: + """Wrap state tensor in a TensorVariable""" + from ..decorators import mark_static_address + + # If we have a source for a tensor already use it, + # if we have not seen a tensor before, stash and use a + # global weak ref source, since it must be an optimizer tensor + # that we have missed + + if tensor_value in self.tensor_to_source: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=True) + source = self.tensor_to_source[tensor_value] + self.static_tensor_names.add(tx.output.module_key_name(source.name)) + elif tensor_value in self.grad_to_source: + source = self.grad_to_source[tensor_value] + else: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=True) + + global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) + source = GlobalWeakRefSource(global_name) + self.static_tensor_names.add(tx.output.module_key_name(source.name)) + + return VariableTracker.build(tx, tensor_value, source) + + def update_list_args( + self, + tx: "InstructionTranslator", + args: Iterable[VariableTracker], + kwargs: Any, + py_args: Iterable[Any], + py_kwargs: Any, + ) -> None: + """Update the args and kwargs to the traced optimizer call""" + for arg, py_arg in zip(args, py_args): + if isinstance(arg, ListVariable): + assert isinstance(py_arg, list), ( + "py_arg should be a list in optimizer variable" + ) + for i, val in enumerate(py_arg): + tx.output.side_effects.mutation(arg) + if isinstance(val, torch.Tensor): + arg.items.append(self.wrap_tensor(tx, val)) + else: + source = arg.source and GetItemSource(arg.source, i) + arg.items.append(VariableTracker.build(tx, val, source)) + + def create_finalizer(self, tx: "InstructionTranslator") -> None: + names_to_delete = self.static_tensor_names + value = self.value + tc = tx.output.tracing_context + + def init_finalizer(gm: torch.fx.GraphModule) -> None: + def clear_static_tensor_refs() -> None: + for name in names_to_delete: + gm._buffers.pop(name, None) + gm._parameters.pop(name, None) + if tc.params_flat: + tc.params_flat.clear() + if tc.params_flat_unwrap_subclasses: + tc.params_flat_unwrap_subclasses.clear() + + weakref.finalize(value, clear_static_tensor_refs) + + tx.output.add_graph_finalizer(init_finalizer) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7f0873e8eb0164a8671c2b6e575e8495da9d0e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py @@ -0,0 +1,236 @@ +""" +This module implements variable tracking for TorchScript objects during Dynamo tracing. + +The TorchScriptObjectVariable class provides specialized handling for TorchScript +objects with strong safety guarantees by: +- Enforcing method-call-only access to prevent unsafe attribute manipulation +- Converting graph breaks into hard errors via _raise_hard_error_if_graph_break +- Proper proxy and source tracking for TorchScript method calls +- Integration with higher-order operators for method call handling + +Key safety features: +- Strict validation that only method calls are allowed (no direct attribute access) +- Immediate error reporting for potentially unsafe operations +- Proper source tracking for debugging and guard installation +- Safe handling of TorchScript object method calls through torchbind + +The module ensures that TorchScript objects are handled safely during tracing +by limiting operations to known-safe patterns and failing fast for unsafe usage. +""" + +import functools +from collections.abc import Callable, Iterable +from typing import Any, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch._guards import Source +from torch._library.opaque_object import ( + is_opaque_reference_type, + is_opaque_type, + is_opaque_value_type, +) +from torch.fx.proxy import Proxy + +from .. import graph_break_hints +from ..eval_frame import skip_code +from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import TupleVariable +from .user_defined import UserDefinedObjectVariable, UserDefinedVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _raise_hard_error_if_graph_break( + reason: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(fn) + def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return fn(*args, **kwargs) + except Unsupported as e: + raise UnsafeScriptObjectError(e.msg) from e + + return graph_break_as_hard_error + + return deco + + +class OpaqueObjectClassVariable(UserDefinedVariable): + """ + A variable that represents an opaque object class (not instance). + Since UserDefinedClassVariable has some special handling for side effects, + we have a separate class here which will directly return the object when + __init__ is called. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def is_python_hashable(self): + return is_opaque_value_type(type(self.value)) + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value})" + + def call_function( # pyrefly: ignore[bad-override] + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # disallow creating reference-type opaque objects in the middle of the + # program + if is_opaque_reference_type(self.value): + # Skip __init__ to prevent dynamo from tracing it during resume + skip_code(self.value.__init__.__code__) + + unimplemented( + gb_type="An opaque object was created in the middle of the program.", + context=f"Opaque object type: {self.value}.", + explanation=( + "Opaque objects cannot be created inside the torch.compile region. " + "They must be created before entering the compiled function." + ), + hints=[ + "Please create the opaque object before calling torch.compile " + "and pass it in as an argument or as a global variable." + ], + ) + + var_args = TupleVariable(list(args)) + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + opaque_obj = self.value( # pyrefly: ignore[not-callable] + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + + return TorchScriptObjectVariable.create(opaque_obj, opaque_obj) + + +class TorchScriptObjectVariable(UserDefinedObjectVariable): + _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} + + @classmethod + def is_matching_cls(cls, user_cls: type) -> bool: + return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls) + + @staticmethod + def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable": + return TorchScriptObjectVariable(proxy, value, **options) + + def __init__( + self, proxy: Proxy, value: Any, source: Optional[Source] = None, **kwargs: Any + ) -> None: + super().__init__(value, **kwargs) + self.proxy = proxy + if isinstance(self.proxy, torch.fx.Proxy): + self.proxy.node.meta["example_value"] = value + self.source = source + + def as_proxy(self) -> Proxy: + return self.proxy + + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." + ) + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + from torch._higher_order_ops.torchbind import call_torchbind + + from ..source import AttrSource + from .higher_order_ops import TorchHigherOrderOperatorVariable + + if is_opaque_value_type(type(self.value)): + res = super().var_getattr(tx, name) + return res + + if hasattr(self.value, "script_class_name") and is_opaque_type( + self.value.script_class_name + ): + # For non-value opaque types, block attribute access + unimplemented( + gb_type="Attempted to access attributes/methods on an OpaqueObject", + context=f"value={self.value}, attr={name}", + explanation="Attribute/method access of OpaqueObjects is not supported.", + hints=[ + "Use custom operators instead of direct attribute/method access.", + ], + ) + + method = getattr(self.value, name, None) + if method is None: + unimplemented( + gb_type="FakeScriptObject missing method implementation", + context=f"value={self.value}, method={name}", + explanation=f"TorchScript object {self.value} doesn't define the method {name}.", + hints=[ + f"Ensure the method {name} is implemented in {self.value}.", + *graph_break_hints.USER_ERROR, + ], + ) + + if not callable(method): + unimplemented( + gb_type="Attempted to access non-callable attribute of TorchScript object", + context=f"value={self.value}, method={name}", + explanation="Attribute accesses of TorchScript objects to non-callable attributes are not supported.", + hints=[ + "Use method calls instead of attribute access.", + ], + ) + assert self.source is not None + return TorchHigherOrderOperatorVariable.make( + call_torchbind, + source=AttrSource(self.source, name), + script_obj_var=self, + method_name=name, + ) + + # We only support method calls on script objects. Interpreting the bytecodes + # should go through var_getattr then call_function instead of call_method. + # + # However, it's possible for call_method to be used directly e.g. for __setattr__. + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." + ) + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + unimplemented( + gb_type="Weird method call on TorchScript object", + context=f"value={self.value}, method={name}", + explanation=( + f"This particular method call ({name}) is not supported (e.g. calling `__setattr__`). " + "Most method calls to TorchScript objects should be supported." + ), + hints=[ + "Avoid calling this method.", + ], + ) + + def as_python_constant(self): + if is_opaque_value_type(type(self.value)): + return self.value + return super().as_python_constant() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7006f5d56ab364d91a974a3cd9e14aab6af317 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py @@ -0,0 +1,95 @@ +from collections.abc import Sequence +from inspect import getattr_static +from typing import Any, TYPE_CHECKING, TypeGuard + +from torch._guards import Source +from torch.backends.cuda import SDPAParams +from torch.fx.proxy import Proxy + +from ..bytecode_transformation import create_call_function +from ..exc import unimplemented +from ..source import AttrSource +from .base import VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +PARAM_NAMES = [ + "query", + "key", + "value", + "attn_mask", + "dropout", + "is_causal", + "enable_gqa", +] + + +class SDPAParamsVariable(VariableTracker): + """Represents the c++ params struct for scaled dot product attention. + This is a read-only container.""" + + @staticmethod + def create( + tx: "InstructionTranslator", value: Any, source: Source + ) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + params = [ + VariableTracker.build(tx, getattr(value, p), AttrSource(source, p)) + for p in PARAM_NAMES + ] + return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) + + def __init__( + self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any + ) -> None: + self.proxy = proxy + self.param_vars = param_vars + super().__init__(**kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.source is None + assert self.param_vars is not None + codegen.add_push_null( + lambda: codegen.load_import_from("torch._C", "_SDPAParams") + ) + codegen.foreach(self.param_vars) + codegen.extend_output(create_call_function(len(self.param_vars), False)) + + def as_proxy(self) -> Proxy: + return self.proxy + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + import torch._C + + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + getattr_static(torch._C._SDPAParams, name) + except AttributeError: + import torch._dynamo.graph_break_hints as graph_break_hints + + unimplemented( + gb_type="unsupported torch._C._SDPAParams attribute", + context=f"name: {name}", + explanation=f"Unable to fetch attribute {name} from torch._C._SDPAParams.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @staticmethod + def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]: + return value is SDPAParams diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..426f50e76d6ab918bfc1862ff6c4ff06556d9f68 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py @@ -0,0 +1,549 @@ +import collections +from collections.abc import Callable +from typing import Any, Optional + +import torch +from torch._dynamo.variables.dicts import ConstDictVariable +from torch._dynamo.variables.lists import TupleVariable +from torch.fx import has_side_effect, Proxy + +from .. import graph_break_hints +from ..bytecode_transformation import create_call_function +from ..exc import TYPE_CHECKING, unimplemented +from ..graph_bytecode_inputs import ( + get_external_object_by_index, + register_graph_created_object, +) +from ..source import CurrentStreamSource +from .base import VariableTracker +from .constant import ConstantVariable +from .ctx_manager import FxTracebackAnnotateVariable +from .lazy import LazyVariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from ..codegen import PyCodegen + +from torch._library.custom_ops import custom_op + + +Tensor = torch.Tensor + + +def new_event(*args: Any, **kwargs: Any) -> int: + event = torch.Event(*args, **kwargs) + return register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + TupleVariable([]), ConstDictVariable({}) + ), + ) + + +def new_stream(*args: tuple[Any], **kwargs: Any) -> int: + stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload] + return register_graph_created_object( + stream, + StreamVariable.make_construct_in_graph_stream_fn( + TupleVariable([]), ConstDictVariable({}) + ), + ) + + +def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None: + cg.add_push_null( + lambda: cg.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + cg(CurrentStreamSource(device)) + cg.extend_output(create_call_function(1, False)) + + +def get_current_stream(device: torch.device) -> int: + stream = torch.accelerator.current_stream(device) + return register_graph_created_object( + stream, lambda _, cg: _codegen_current_stream(device, cg) + ) + + +def _get_stream_by_index(index: int) -> torch.Stream: + stream = get_external_object_by_index(index) + assert isinstance(stream, torch.Stream), ( + f"Fork/join stream expected a stream object at index {index}" + ) + return stream + + +def _get_event_by_index(index: int) -> torch.Event: + event = get_external_object_by_index(index) + assert isinstance(event, torch.Event), ( + f"Record/wait event expected an event object at index {index}" + ) + return event + + +@custom_op("streams::fork", mutates_args=()) +def fork_stream( + from_index: int, # kept to make stream transitions clearer + to_index: int, +) -> None: + torch.accelerator.set_stream(_get_stream_by_index(to_index)) + + +@fork_stream.register_fake +def _( + from_index: int, # kept to make stream transitions clearer + to_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.fork.default) + + +@custom_op("streams::join", mutates_args=()) +def join_stream(from_index: int, to_index: int) -> None: + torch.accelerator.set_stream(_get_stream_by_index(to_index)) + + +@join_stream.register_fake +def _( + from_index: int, + to_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.join.default) + + +@custom_op("streams::record_event", mutates_args=()) +def record_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.record_event(event) + + +@record_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.record_event.default) + + +@custom_op("streams::wait_event", mutates_args=()) +def wait_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.wait_event(event) + + +@wait_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.wait_event.default) + + +@custom_op("streams::wait_stream", mutates_args=()) +def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None: + waiting = _get_stream_by_index(waiting_stream_index) + waited_on = _get_stream_by_index(waited_on_stream_index) + waiting.wait_stream(waited_on) + + +@wait_stream.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.wait_stream.default) + + +@custom_op("streams::sync_dealloc", mutates_args=()) +def sync_dealloc( + wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor +) -> None: + """An op which waits on an event and moves the last usage of to_dealloc + after the wait, so that after the sync occurs, the deallocation or + subsequent reuse of the tensor's memory will be guaranteed to happen + after a side stream is finished using it. + See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream + for more details""" + torch.ops.streams.wait_event.default(wait_event_index, src_stream_index) + + +has_side_effect(torch.ops.streams.sync_dealloc.default) + + +@custom_op("streams::record_stream", mutates_args=()) +def record_stream(tensor: torch.Tensor, stream_index: int) -> None: + tensor.record_stream(_get_stream_by_index(stream_index)) + + +@record_stream.register_fake +def _( + src_stream_index: int, + wait_event_index: int, + to_dealloc: torch.Tensor, +) -> None: + pass + + +class SymbolicStreamState: + """Track the currently entered stream if any""" + + def __init__(self) -> None: + from ..source import CurrentStreamSource + + cur_stack: list[StreamVariable] = [] + if torch.accelerator.is_available(): + stream_var = LazyVariableTracker.create( + torch.accelerator.current_stream(), + source=CurrentStreamSource(torch.accelerator.current_stream().device), + ) + cur_stack = [stream_var] # type: ignore[list-item] + + self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque( + cur_stack + ) + + def enter_stream(self, stream: "StreamVariable") -> None: + self.cur_stream_stack.append(stream) + + def exit_stream(self) -> None: + self.cur_stream_stack.pop() + + def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable": + if device is not None: + for stream in reversed(self.cur_stream_stack): + if stream.device == device: + return stream + + return self.cur_stream_stack[-1] + + def in_stream_context(self) -> bool: + return len(self.cur_stream_stack) > 0 + + +class StreamContextVariable(FxTracebackAnnotateVariable): + """This represents torch.cuda.StreamContext""" + + @staticmethod + def create( + tx: "InstructionTranslator", + stream_to_enter: "StreamVariable", + **kwargs: dict[str, Any], + ) -> "StreamContextVariable": + return StreamContextVariable( + stream_to_enter, + **kwargs, + ) + + def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None: + self.stream = stream + super().__init__( + target_values={"stream": self.get_stream().user_object_index}, + initial_values=None, + **kwargs, + ) + + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # to stream, from stream is the order of the arguments + # we are entering the target, and leaving the initial stream + tx.symbolic_stream_state.enter_stream(self.get_stream()) + return super().enter(tx) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # to stream, from stream is the order of the arguments + # we are leaving the target, and entering the initial stream + tx.symbolic_stream_state.exit_stream() + return super().exit(tx, *args) + + def supports_graph_breaks(self) -> bool: + return True + + def get_stream(self) -> "StreamVariable": + assert self.stream, "Stream context should have a separate stream" + return self.stream + + +class StreamVariable(StreamContextVariable): + """Represents the device-agnostic torch.Stream class""" + + def __init__( + self, + proxy: Proxy, + value: torch.Stream, + user_object_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + # Index into the user object table + # used to pass arbitrary objects to the graph + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + + self.proxy = proxy + self.value = value + # pyrefly: ignore [read-only] + self.device = value.device + # pyrefly: ignore [read-only] + self.user_object_index = user_object_index + super().__init__(None, **kwargs) + + def python_type(self) -> type: + return torch.Stream + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert hasattr(self.value, name), f"no stream method found named {name}" + + from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait_stream", "synchronize", "wait_event"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name == "record_event": + return wrap_fx_proxy_cls( + target_cls=EventVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + from ..guards import GuardBuilder, install_guard + + if self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + + # NB : Checking for mutation is necessary because we compare + # constant values + other = args[0] + if not isinstance(other, StreamVariable): + return ConstantVariable.create(NotImplemented) + + if other.source: + assert self.source is not None + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type] + ) + + return super().call_method(tx, name, args, kwargs) + + def as_proxy(self) -> Proxy: + return self.proxy + + def module_name(self) -> str: + return "torch._C" + + def fn_name(self) -> str: + return "Stream" + + def reconstruct(self, codegen: "PyCodegen") -> None: + # If we got here, this stream is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + if self.user_object_index is not None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, + "get_external_object_by_index", + ) + ) + codegen.append_output(codegen.create_load_const(self.user_object_index)) + codegen.extend_output(create_call_function(1, False)) + else: + # This will support the legacy behavior + prefix = f"_stream_{self.device}" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) + + def get_stream(self) -> "StreamVariable": + return self + + @staticmethod + def make_construct_in_graph_stream_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_stream" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) + codegen.extend_output(create_call_function(1, False)) + + return fn + + +class EventVariable(VariableTracker): + def __init__( + self, + proxy: Proxy, + value: torch.Event, + user_object_index: Optional[int], + **kwargs: Any, + ) -> None: + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + self.user_object_index = user_object_index + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from ..utils import proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name == "wait": + tx.output.create_proxy( + "call_function", + torch.ops.streams.wait_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "record": + tx.output.create_proxy( + "call_function", + torch.ops.streams.record_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "synchronize": + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + else: + method_name = ( + f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}" + ) + unimplemented( + gb_type="Unsupported event method", + context=str(name), + explanation=f"Dynamo doesn't support tracing the {method_name} method. " + f"We currently support wait, record, synchronize, and query.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def as_proxy(self) -> Proxy: + return self.proxy + + @staticmethod + def _get_stream_arg( + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "StreamVariable": + stream_arg = None + if args: + stream_arg = args[0] + elif kwargs: + stream_arg = kwargs.get("stream") + + if not stream_arg: + stream_arg = tx.symbolic_stream_state.cur_stream() + + return stream_arg # type: ignore[return-value] + + @staticmethod + def make_construct_in_graph_event_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_event" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) + codegen.extend_output(create_call_function(1, False)) + + return fn + + def reconstruct(self, codegen: "PyCodegen") -> None: + # If we got here, this event is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there. + prefix = "_event" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..94b72200c72fa2e73a59a1bd0333d30e7ddc85f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py @@ -0,0 +1,1889 @@ +# mypy: ignore-errors + +""" +This module contains variable tracker classes for handling tensors and tensor-related operations in Dynamo. + +The main class is TensorVariable which represents torch.Tensor inputs and intermediate values in the FX graph. +It handles tensor operations, method calls, and maintains metadata about tensor properties like dtype, device, etc. + +Other key classes include: +- SymNodeVariable: Represents symbolic scalars (int/float/bool) used for size computation and unspecialized values +- NumpyNdarrayVariable: Handles numpy array interop through torch._numpy +- UnspecializedPythonVariable: Represents unspecialized Python numeric values as 1-element tensors +- TensorSubclassVariable: Handles tensor subclasses with __torch_function__ overrides +- UntypedStorageVariable: Represents tensor storage objects +- DataPtrVariable: Handles tensor data pointer operations + +These classes work together to track tensor operations and properties during Dynamo's tracing process. +""" + +import functools +import logging +import operator +import textwrap +import traceback +import types +from collections.abc import Sequence +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import sympy + +import torch._numpy as tnp +import torch.fx +import torch.random +from torch._dynamo import compiled_autograd +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import ( + guard_scalar, + GuardOnDataDependentSymNode, + has_free_symbols, + is_symbolic, + SymTypes, +) +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config, graph_break_hints, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..exc import ( + unimplemented, + UnknownPropertiesDuringBackwardTrace, + UserError, + UserErrorType, +) +from ..external_utils import call_hook_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import ( + fqn, + get_custom_getattr, + get_fake_value, + get_real_value, + guard_if_dyn, + object_has_getattribute, + product, + proxy_args_kwargs, + raise_args_mismatch, + set_example_value, + tensortype_to_dtype, +) +from .base import AttributeMutationNew, ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .lists import ListIteratorVariable, SizeVariable +from .user_defined import UserDefinedClassVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) + +# Ops that allow tensor tensor +supported_tensor_comparison_ops = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + "is": operator.is_, + "is not": operator.is_not, +} +# Ops that allow tensor None +supported_const_comparison_ops = { + "is": operator.is_, + "is not": operator.is_not, + "==": operator.eq, + "!=": operator.ne, +} +supported_comparison_ops = { + **supported_tensor_comparison_ops, + **supported_const_comparison_ops, +} +supported_tensor_comparison_op_values = dict.fromkeys( + supported_tensor_comparison_ops.values() +) +supported_const_comparison_op_values = dict.fromkeys( + supported_const_comparison_ops.values() +) + + +def is_bound_tensor_method(value): + return ( + callable(value) + and not torch._dynamo.utils.object_has_getattribute(value) + and hasattr(value, "__self__") + and isinstance(value.__self__, torch.Tensor) + and getattr(value.__self__, value.__name__, None) + ) + + +# instead of using inspect.getattr_static, we directly lookup the appropriate +# dicts. It is necessary to keep the torch._C.TensorBase first in the or +# operation, because the second arg takes priority in or operation when there +# are common keys. +all_tensor_attrs = torch._C.TensorBase.__dict__ | torch.Tensor.__dict__ + + +class TensorVariable(VariableTracker): + """A torch.Tensor input or an intermediate value in the FX graph""" + + _nonvar_fields = { + "proxy", + "dtype", + "device", + "layout", + "ndim", + "size", + "stride", + "requires_grad", + "is_quantized", + "is_contiguous", + "is_nested", + "is_sparse", + "class_type", + "specialized_value", + "_is_name_set", + *VariableTracker._nonvar_fields, + } + + def get_real_value(self): + """ + Get the actual value represented by this variable if computation is run + using the user-provided inputs. + NOTE: this runs actual tensor computation and may be + slow and memory-intensive. + """ + return get_real_value(self.proxy.node, self.proxy.tracer) + + def __init__( + self, + proxy: torch.fx.Proxy, + *, + dtype, + device, + layout, + ndim, + requires_grad, + is_nested, + is_quantized, + is_sparse, + class_type, + has_grad_fn, + _size=None, + stride=None, + is_contiguous=None, + _is_name_set=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.proxy = proxy + self.dtype = dtype + self.device = device + self.layout = layout + self.ndim = ndim + self._size = _size # this is accessed as a property for validation + self.stride = stride + self.requires_grad = requires_grad + self.is_quantized = is_quantized + self.is_contiguous = is_contiguous + self.is_nested = is_nested + self.is_sparse = is_sparse + self.class_type = class_type + self.has_grad_fn = has_grad_fn + if _is_name_set is None: + # no need to rename inputs + _is_name_set = self.proxy.node.op == "placeholder" + self._is_name_set: bool = _is_name_set + + def synchronize_attributes(self, tx, target_cls=None): + from .builder import get_specialized_props, infer_subclass_type + + if target_cls is None: + target_cls = type(self) + + example_value = self.proxy.node.meta.get("example_value") + specialized_props = get_specialized_props( + target_cls, tx, example_value, infer_subclass_type(example_value) + ) + for k, v in specialized_props.items(): + setattr(self, k, v) + + def debug_repr(self): + # TODO: strip off fake tensor from repr here + return repr(self.proxy.node.meta["example_value"]) + + def as_proxy(self): + return self.proxy + + def python_type(self): + return self.class_type + + def is_tensor(self) -> bool: + return True + + @staticmethod + def specialize(value: torch.Tensor): + props = { + "dtype": value.dtype, + "device": value.device, + "layout": value.layout, + "ndim": int(value.ndim), + "requires_grad": value.requires_grad, + "is_nested": value.is_nested, + "is_quantized": value.is_quantized, + "is_sparse": value.is_sparse, + "class_type": type(value), + } + try: + props["has_grad_fn"] = value.grad_fn is not None + except Exception: + # Workaround for issues with create_parameter_op in Dynamo. Reading + # grad_fn should never cause an issue. + props["has_grad_fn"] = False + + if is_sparse_any(value) and not has_free_symbols(value): + props["_size"] = tuple( + int(s) if is_symbolic(s) else s for s in value.size() + ) + elif not has_free_symbols(value): + # this is a fully static shape, and the keys on props here inform specialization. + # We have to cast to int here, because these might get accessed as ConstantVariable, which has + # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant + # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for + # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and + # I'd like to keep it around for now. + props["_size"] = tuple( + # the non is_symbolic case applies to the jagged layout + # NestedTensor case as singleton ints are not symbolic + int(s) if is_symbolic(s) else s + for s in value.size() + ) + props["stride"] = tuple(value.stride()) + if torch._C._functorch.is_batchedtensor(value): + # Batched tensors does not support contiguity patterns, so + # we refrain from computing the `is_contiguous` property + props["is_contiguous"] = None + else: + props["is_contiguous"] = tuple( + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ) + return props + + def dynamic_getattr(self, tx: "InstructionTranslator", name): + fake_val = self.proxy.node.meta["example_value"] + # For getattrs on tensors without sources, + # we can do better than the default (creating a GetAttrVariable) + # if: + # (1) the tensor is a traceable tensor subclass + # (2) We are getattr'ing an inner tensor from that subclass + if not self.source and is_traceable_wrapper_subclass(fake_val): + attrs, _ctx = fake_val.__tensor_flatten__() + proxy = getattr(self.as_proxy(), name) + example_value = getattr(fake_val, name) + if name in attrs: + # attrs returned from tensor_flatten are always tensors + assert isinstance(example_value, torch.Tensor) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value) + # any other attributes on the subclass (that are not methods) + # are assumed to be constant metadata. + elif not callable(example_value): + return VariableTracker.build(tx, example_value) + + if not (self.source and self.source.subguards_allowed()): + raise NotImplementedError + + # For local source, we associate the real value. We use this real value + # for implementing getattr fallthrough on the variable tracker base class. + + # Note - this scope construction is mirrored in guards + # A subsequent PR will introduce a util. + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + # We raise in case we get a typerror bug w/ SuperSource. + # SuperSource has bugs in it atm, and can produce code like + # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, + # L['mod'].model.model.encoder.embed_positions)", scope) + # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. + _input_associated_real_value = eval(self.source.name, scope) + except Exception as exc: + raise NotImplementedError from exc + + if _input_associated_real_value is None: + raise NotImplementedError + + if object_has_getattribute(_input_associated_real_value): + raise NotImplementedError + + if get_custom_getattr(_input_associated_real_value): + raise NotImplementedError + + real_value = getattr(_input_associated_real_value, name) + + attr_source = AttrSource(self.source, name) + + # Typically we'd want to use variable builder here + # but unfortunately id(real_value.__self__) is not id() + if is_bound_tensor_method(real_value): + # No need to install the guard because its a bound tensor method + from .misc import GetAttrVariable + + return GetAttrVariable( + self, name, source=attr_source, py_type=type(real_value) + ) + + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) + return VariableTracker.build(tx, real_value, attr_source) + + def method_attr_ndim(self, tx): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + else: + return self.call_method(tx, "dim", [], {}) + + def method_attr_dtype(self, tx): + if self.dtype is not None: + return ConstantVariable.create(self.dtype) + + def method_attr_device(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device) + + def method_attr_layout(self, tx): + if self.layout is not None: + return ConstantVariable.create(self.layout) + + def method_attr_is_cuda(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device.type == "cuda") + + def method_attr_shape(self, tx): + if self.valid_size(): + sizes = [variables.ConstantVariable.create(x) for x in self.size] + return SizeVariable(sizes) + else: + return self.call_method(tx, "size", [], {}) + + def method_attr_requires_grad(self, tx): + if self.requires_grad is not None: + return ConstantVariable.create(self.requires_grad) + + def method_attr_is_quantized(self, tx): + if self.is_quantized is not None: + return ConstantVariable.create(self.is_quantized) + + def method_attr_is_sparse(self, tx): + if self.is_sparse is not None: + return ConstantVariable.create(self.is_sparse) + + def method_attr_is_nested(self, tx): + if self.is_nested is not None: + return ConstantVariable.create(self.is_nested) + + def method_attr_retain_grad(self, tx): + unimplemented( + gb_type="Tensor.retain_grad() with AOTDispatcher", + context=f"var_getattr {self} retain_grad", + explanation="`Tensor.retain_grad()` does not work with AOTDispatcher.", + hints=[], + ) + + def method_attr_data(self, tx): + return variables.TorchInGraphFunctionVariable( + torch._C._autograd._get_data_attr + ).call_function(tx, [self], {}) + + def method_attr_grad_fn(self, tx): + if self.has_grad_fn: + unimplemented( + gb_type="Tensor with grad_fn()", + context=f"var_getattr {self} grad_fn", + explanation="Dynamo does not support tracing tensors with a grad_fn directly.", + hints=[], + ) + else: + return variables.ConstantVariable(None) + + def method_attr__version(self, tx): + from ..tensor_version_op import _tensor_version + + return variables.TorchInGraphFunctionVariable(_tensor_version).call_function( + tx, [self], {} + ) + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + from . import GetAttrVariable + from .builtin import BuiltinVariable + + # TODO - This is not a good solution but solves an accuracy issue. + # Today, var_getattr returns GetAttrVariable for both non-existent + # attributes and existing attributes. This is a bug and requires more + # deep dive. + if name in all_tensor_attrs: + return ConstantVariable(True) + + try: + var = BuiltinVariable(getattr).call_function( + tx, [self, ConstantVariable(name)], {} + ) + # in the event that TensorVariable returns NotImplemented + # BuiltinVariable.call_getattr returns GetAttrVariable + ret_val = not isinstance(var, GetAttrVariable) + except AttributeError: + ret_val = False + + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + return ConstantVariable(ret_val) + + def var_getattr(self, tx: "InstructionTranslator", name): + if self.is_strict_mode(tx): + if name in self._strict_mode_banned_ops(): + unimplemented( + gb_type="Strict mode banned op", + context=f"var_getattr {self} {name}", + explanation=f"Getattr invocation '{name}' in strict mode is not supported.", + hints=[ + f"Remove `{name}` from the list of banned ops by " + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`.", + ], + ) + elif name in self._strict_mode_conditional_banned_ops(): + raise UnknownPropertiesDuringBackwardTrace( + f"Unknown property {name} during speculating backward, dynamo will insert contiguous call ahead and speculate it again" # noqa: B950 + ) + + if name == "__class__": + return UserDefinedClassVariable(self.python_type()) + + handler = getattr(self, f"method_attr_{name}", None) + result = handler(tx) if handler is not None else None + + # Add a guard for type matching, these guards are checked before tensor guards + # In some cases, a . guard can be evaluated first, and break if + # is later changed to another type + if ( + result is not None + and self.source + and self.source.subguards_allowed() + and not ( + name not in ("grad", "requires_grad") and result.is_python_constant() + ) + ): + install_guard(self.make_guard(GuardBuilder.TYPE_MATCH)) + result.source = AttrSource(self.source, name) + + # It's hard to get inplace view (metadata mutation) on graph input work properly across + # dynamo/aot/inductor, just fall back. + if self.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags + ): + # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. + return variables.misc.DelayGraphBreakVariable( + source=AttrSource(self.source, name), + msg="Getting an inplace view on a graph input is not supported", + ) + + # For attributes (not methods) that were not caught in the special handling above, + # (e.g. tensor.real), we handle these generically, assuming that the output type is + # a tensor. + if result is None and name != "grad": + + def try_generic_attr_handling(): + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + static_attr = all_tensor_attrs.get(name, None) + if static_attr is None: + return None + + # Make sure this is an attribute, not a method. + # type(torch.Tensor.H) should be "getset_descriptor" + # This is a because of CPython implementation, see THPVariableType: + # these attributes are implemented under tp_getset, which appear + # as `getset_descriptor`s, (compared to, say, methods which appear + # as `method_descriptor`s) + if type(static_attr) is not types.GetSetDescriptorType: + return None + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + result = try_generic_attr_handling() + + if result is None: + result = self.dynamic_getattr(tx, name) + + if result is None: + raise NotImplementedError + return result + + def call_id(self, tx): + if not self.source: + unimplemented( + gb_type="Unsupported call_id() without source", + context=f"call_id {self}", + explanation="call_id() not supported for sourceless TensorVariable.", + hints=[], + ) + + # For local source, we associate the real value. We use this real value + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + _input_associated_real_value = eval(self.source.name, scope) + except Exception as exc: + unimplemented( + gb_type="Error getting associated real value", + context=f"call_id {self}", + explanation="Dynamo encountered an error while trying to " + "get the associated real value.", + hints=[], + from_exc=exc, + ) + + if _input_associated_real_value is None: + unimplemented( + gb_type="call_id() without associated real value", + context=f"call_id {self}", + explanation="Dynamo could not find an associated real value for the tensor.", + hints=[], + ) + + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + id_value = id(_input_associated_real_value) + return ConstantVariable.create(id_value) + + def has_unpack_var_sequence(self, tx): + return self.ndim > 0 + + def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): + from .builder import wrap_fx_proxy_cls + + if self.valid_size(): + size_len = len(self.size) + else: + size_var = self.call_method(tx, "size", [], {}) + assert isinstance(size_var, SizeVariable) + size_len = len(size_var.items) + # Ensure we don't unpack a scalar tensor. + assert size_len != 0, "Can't unpack scalar tensors." + + if self.valid_size(): + length = self.size[0] + else: + dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through + # symbolic_shapes, but that end up as int/sympy.Integer + assert ( + isinstance(dyn_length, SymNodeVariable) + or dyn_length.is_python_constant() + ) + if isinstance(dyn_length, SymNodeVariable): + length = dyn_length.evaluate_expr(tx.output) + else: + length = dyn_length.as_python_constant() + + if idxes is None: + idxes = range(length) + else: + assert len(idxes) == length, ( + f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements." + ) + return [ + wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i]) + for i in idxes + ] + + def call_tree_map( + self, + tx, + tree_map_fn: "UserFunctionVariable", + map_fn, + rest, + tree_map_kwargs, + ) -> "VariableTracker": + return map_fn.call_function(tx, [self, *rest], {}) + + def valid_size(self): + return self._size is not None + + @property + def size(self): + assert self._size is not None, "accessing None size in TensorVariable" + return self._size + + def _strict_mode_banned_ops(self): + return torch._dynamo.config._autograd_backward_strict_mode_banned_ops + + def _strict_mode_conditional_banned_ops(self): + return ( + torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops + ) + + def call_method( + self, + tx, + name, + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): + unimplemented( + gb_type="Illegal method invocation in strict mode", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo currently does not support this method " + f"({name}) invocation in strict mode.", + hints=[], + ) + + # Only override builtin tensor methods + # The user can manually add override handling + # with a decorator for other methods (e.g. a dispatch subclass with other methods) + static_attr = all_tensor_attrs.get(name, None) + is_base_tensor_method = static_attr is not None + + if ( + can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs) + and is_base_tensor_method + ): + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(static_attr) + else: + func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name)) + + return dispatch_torch_function( + tx, func_var, tuple([self] + list(args)), kwargs + ) + + """ + Dispatch to a method-specific handler defined below. If the + handler returns None (or doesn't exist) we put the method call + in the graph. + """ + + # This is seen in inspect signature where we check if the value is a default value + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): + return variables.ConstantVariable(False) + + # For historical reasons, these ops decompose down to syntactically + # invalid aten ops because they contain the python keyword `from`, see + # discussions in #151432 for more details. + # We graph break for now since this use case is uncommon. + if name == "random_": + unimplemented( + gb_type="Tensor.random_ op", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Use the out-of-place version of this op", + *graph_break_hints.SUPPORTABLE, + ], + ) + elif name == "uniform_" and "from" in kwargs: + unimplemented( + gb_type="Tensor.uniform_ op called with `from` keyword", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Avoid using the `from` keyword.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + try: + handler_method = getattr(self, f"method_{name}") + except AttributeError: + pass + else: + try: + result = handler_method(*args, **kwargs) + if result: + return result + except TypeError as e: + unimplemented( + gb_type="Unhandled args for method", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered an error while calling " + f"the method `{name}`.", + hints=[], + from_exc=e, + ) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def method_size(self, *args, **kwargs): + return self._method_size_stride("size", *args, **kwargs) + + def method_stride(self, *args, **kwargs): + return self._method_size_stride("stride", *args, **kwargs) + + def _method_size_stride(self, name, dim=None): + dim = guard_if_dyn(dim) + + def make_const_size_variable(x, **options): + return SizeVariable( + [ConstantVariable.create(y, **options) for y in x], **options + ) + + RetVariable = ( + make_const_size_variable if name == "size" else ConstantVariable.create + ) + + # Technically, this should not be necessary, but I'm including it + # for enhanced BC, in case example_value is sometimes not set + # (it really should always be set though!) + if name != "size": + r = getattr(self, name) + elif name == "size" and self.valid_size(): + r = self.size + else: + r = None + + if r is not None: + if dim is None: + return RetVariable(r) + else: + return ConstantVariable.create(r[dim]) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + if dim is None: + fake_r = getattr(fake, name)() + if not has_free_symbols(fake_r): + # int conversion for safety, in case a SymInt refined + # to constant + return RetVariable(tuple(int(r) for r in fake_r)) + else: + fake_r = getattr(fake, name)(dim) + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + def method_numel(self): + if self.valid_size(): + return ConstantVariable.create(product(self.size)) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + fake_r = fake.numel() + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + method_nelement = method_numel + + def method_dim(self): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + + method_ndimension = method_dim + + def method_is_floating_point(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_floating_point) + + def method_is_inference(self): + if config.fake_tensor_disable_inference_mode: + unimplemented( + gb_type="Encountered tensor.is_inference() during tracing", + context="", + explanation="tensor.is_inference() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + if (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create(fake.is_inference()) + + def method_is_complex(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_complex) + + def method_is_contiguous(self, memory_format=None): + memory_format = ( + memory_format.as_python_constant() + if memory_format is not None + else torch.contiguous_format + ) + if self.is_contiguous is not None: + return ConstantVariable.create(memory_format in self.is_contiguous) + elif (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create( + fake.is_contiguous(memory_format=memory_format) + ) + + def method_type(self, dtype=None, non_blocking=False, **kwargs): + if ( + dtype is None + and self.dtype is not None + and isinstance(self.device, torch.device) + ): + tensortype = next( + k for k, v in tensortype_to_dtype.items() if self.dtype in v + ) + if self.device.type == "cpu": + return ConstantVariable.create(f"torch.{tensortype.__name__}") + else: + return ConstantVariable.create( + f"torch.{self.device.type}.{tensortype.__name__}" + ) + elif ( + dtype is not None + and fqn(type(dtype.as_python_constant())) == "torch.tensortype" + ): + # torch.FloatTensor, etc. are all of type "torch.tensortype". + # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type. + # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args) + tensor_type = dtype.as_python_constant() + tensor_type_const = ConstantVariable.create(fqn(tensor_type)) + + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + if non_blocking: + kwargs = {"non_blocking": non_blocking, **kwargs} + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + "type", + *proxy_args_kwargs([self, tensor_type_const], kwargs), + ), + ) + + def method_as_subclass(self, cls): + if isinstance(cls, TensorSubclassVariable) and cls.source: + from ..symbolic_convert import InstructionTranslator + from .torch_function import TensorWithTFOverrideVariable + + tx = InstructionTranslator.current_tx() + py_cls = cls.as_python_constant() + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, self, py_cls, cls.source + ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + unimplemented( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def method_get_device(self): + if isinstance(self.device, torch.device): + index = self.device.index if self.device.type != "cpu" else -1 + return ConstantVariable.create(index) + + def method_element_size(self): + return ConstantVariable.create(self.dtype.itemsize) + + def method_numpy(self, *, force=False): + if not config.trace_numpy: + unimplemented( + gb_type="Tensor.numpy() with trace_numpy=False", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the `trace_numpy` " + "configuration was manually disabled.", + hints=[ + "Set `torch._dynamo.config.trace_numpy = True` to allow " + "Dynamo to trace through NumPy.", + ], + ) + if not np: + unimplemented( + gb_type="Tensor.numpy() without NumPy installed", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the NumPy library " + "is not available in the current environment.", + hints=[ + "Ensure NumPy is installed in your Python environment.", + ], + ) + if self.layout != torch.strided: + raise TypeError( + f"can't convert {self.layout} layout tensor to numpy. Use Tensor.to_dense() first" + ) + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # We don't check that the tensor is on CPU when force is False, as this + # allows us to execute NumPy code on CUDA. Same for requires_grad=True + if force and force.as_python_constant(): + # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) + t = self.call_method(tx, "detach", [], {}) + proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {}) + else: + # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable + proxy = tx.output.create_proxy( + "call_method", "view_as", *proxy_args_kwargs([self, self], {}) + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def method_tolist(self): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + def tolist(tensor, sub_proxy): + def wrap(i, sub_proxy): + return wrap_fx_proxy( + tx, + sub_proxy.item(), + ) + + if tensor.dtype not in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + unimplemented( + gb_type="Tensor.tolist() with non-integer tensor", + context=f"call_method {self} to_list", + explanation="Dynamo currently does not support tracing " + "`tolist()` on non-integer tensors.", + hints=[ + "Ensure the input tensor to `tolist()` is an integer " + "type (e.g., int8, int16, int32, int64)." + ], + ) + + if tensor.dim() == 0: + return wrap(tensor, sub_proxy) + + if tensor.dim() == 1: + return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] + + return [ + tolist(sub_tensor, sub_proxy=sub_proxy[i]) + for i, sub_tensor in enumerate(tensor) + ] + + tensor = self.as_proxy().node.meta["example_value"] + out = tolist(tensor, self.as_proxy()) + return VariableTracker.build(tx, out) + + def method_backward(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.backward() call", + context=f"call_method {self} backward {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.backward()`.", + hints=[*graph_break_hints.FUNDAMENTAL], + ) + + def method_data_ptr(self, *args, **kwargs): + return DataPtrVariable(self) + + def method_item(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # We enable capture_scalar_outputs when full_graph=True by default. + if not tx.one_graph and not config.capture_scalar_outputs: + self._warn_capture_scalar_outputs() + unimplemented( + gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False", + context=f"call_method {self} item {args} {kwargs}", + explanation="Dynamo does not support tracing `Tensor.item()` " + "with config.capture_scalar_outputs=False.", + hints=[ + "Set `torch._dynamo.config.capture_scalar_outputs = True` " + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` " + "to include these operations in the captured graph.", + ], + ) + + def method___getitem__(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + if isinstance(args[0], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. Rewrite as a regular torch op which will + # trace fine + fn, args = ( + torch.select, + [ + variables.ConstantVariable.create(0), + args[0], + ], + ) + else: + fn = operator.getitem + + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs([self] + list(args), kwargs), + ) + + return wrap_fx_proxy(tx, proxy) + + @staticmethod + @functools.cache + def _warn_capture_scalar_outputs(): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack_formatted = "".join(traceback.format_list(user_stack)) + log.warning( + textwrap.dedent( + """\ + Graph break from `Tensor.item()`, consider setting: + torch._dynamo.config.capture_scalar_outputs = True + or: + env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 + to include these operations in the captured graph. + + Graph break: from user code at: + %s + """ + ), + user_stack_formatted, + ) + + def method___len__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + + def method___iter__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + + def method_addcmul_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + from .. import polyfills + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.addcmul_inplace), + [self, tensor1, tensor2, value], + {}, + ) + + def method___setitem__(self, key, value): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + proxy = tx.output.create_proxy( + "call_function", + operator.setitem, + *proxy_args_kwargs([self, key, value], {}), + ) + + if value.is_tensor(): + # [Note: Tensor.__setitem__ and VariableTracker metadata] + # At this point, we proxied a node representing `self[key] = value` into the graph. + # When executed, this node will mutate `self`'s tensor metadata, so it's important + # even during tracing to propagate. For example: + # value.requires_grad is True => self.requires_grad becomes True + # value.requires_grad is True => self.has_grad_fn becomes True + + # Not sure if __setitem__ can ever save activations, disabling just in case + + # Ignore fresh unbacked symbols that could arise from the internal indexing (selection), + # that happen in code like t[idx] += 1 when idx is unbacked. Namely the selection + # during 'setitem'. + # When the selection happens if idx is unbacked we allocate a new unbacked symbol for the + # storage offset in select_meta, but the output of the operation 'setitem' does not depend + # on the selection. + with ( + torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), + tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols() + if tx.fake_mode and tx.fake_mode.shape_env + else nullcontext(), + ): + get_fake_value(proxy.node, tx, allow_non_graph_fake=False) + + vt = value + if isinstance(vt, variables.lazy.LazyVariableTracker): + vt = variables.lazy.LazyVariableTracker.realize_all(vt) + + self.synchronize_attributes(tx, type(vt)) + + if config.use_graph_deduplication or config.track_nodes_for_deduplication: + tx.output.region_tracker.add_node_mutation(proxy.node, 0) + + return ConstantVariable.create(None) + + def method_resize_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.resize_() call", + context=f"call_method {self} resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_()`.", + hints=[], + ) + + def method_resize_as_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.resize_as_() call", + context=f"call_method {self} resize_as_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_as_()`.", + hints=[], + ) + + def method_sparse_resize_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.sparse_resize_() call", + context=f"call_method {self} sparse_resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + hints=[], + ) + + def method_sparse_resize_and_clear_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.sparse_resize_and_clear_() call", + context=f"call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + hints=[], + ) + + def method_set_(self, *args, **kwargs): + if len(args) > 1: + # torch.Tensor.set_() has several overloads. + # aten::set_.source_Tensor(Tensor) gets special handling + # in AOTAutograd and functionalization, because it is the most common + # overload and is used by FSDP. + # graph-breaking on aten::set_source_Tensor_storage_offset for now, + # unless we find that we need to make it work. + unimplemented( + gb_type="Unsupported Tensor.set_() call", + context=f"call_method {self} set_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.set_()` " + "overloads that include more than one argument.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def method_add_(self, other, *, alpha=None): + if alpha is not None: + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [other, alpha], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method_addcdiv_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + result = variables.TorchInGraphFunctionVariable(torch.div).call_function( + tx, [tensor1, tensor2], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, value], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method___contains__(self, arg): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # Rewrite __contains__ here so that downstream passes can trace through + # without dealing with unbacked symbool. Roughly the code we translate is: + # def __contains__(self, x): + # return (x == self).any().item() + result = variables.TorchInGraphFunctionVariable(torch.eq).call_function( + tx, [self, arg], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.any).call_function( + tx, [result], {} + ) + return result.call_method(tx, "item", [], {}) + + def method_redistribute(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def redistribute_fn_with_prim_types(x): + return x.redistribute(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + redistribute_fn_with_prim_types.__name__ = "prim_redistribute" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + redistribute_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_to_local(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + + grad_placements_vt = kwargs.get( + "grad_placements", ConstantVariable.create(None) + ) + if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable): + # grad_placement is a sequence-like structure, iterate over the value + grad_placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [grad_placements_vt], {} + ) + + if kwargs.get("grad_placements") is not None: + kwargs["grad_placements"] = grad_placements_vt + + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def to_local_fn_with_prim_types(x): + return x.to_local(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + to_local_fn_with_prim_types.__name__ = "prim_to_local" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + to_local_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_register_hook(self, *args, **kwargs): + return self._method_register_hook("register_hook", *args, **kwargs) + + def method_register_post_accumulate_grad_hook(self, *args, **kwargs): + return self._method_register_hook( + "register_post_accumulate_grad_hook", *args, **kwargs + ) + + def _method_register_hook(self, name: str, hook: VariableTracker): + # Note - do not arbitrarily add hooks here - make sure they match the same contract + # see [On tensor.register_hook] + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + if not self.source: + if not compiled_autograd.compiled_autograd_enabled: + # TODO(voz): + # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary + # python state. + # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run + # them in a compiled bwd without re-entering dynamo as compiled_autograd does. + # + # Discussion point 1 - Should we bypass this if nopython/fullgraph = True? + # No. Because this was going to be a graph break anyway - this check does not + # introduce new graph breaks where there were none. + # + # Discussion point 2 - Should we defer this check to backwards? + # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user + # would have no recourse - their forward traces just fine, but will fail at backwards unless + # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) + # then they have nothing they can do except disable compile. + unimplemented( + gb_type="Compilation of intermediate hooks requires compiled autograd", + context=f"var_getattr {self} {name}", + explanation="Dynamo must be in compiled_autograd to register hooks.", + hints=[], + ) + + hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) + + def _register_hook_trampoline(tensor, bw_state): + register_hook = getattr(tensor, name) + register_hook( + functools.partial( + trace_wrapped, + fn=call_hook_from_backward_state, + bw_state=bw_state, + hook_name=hook_name, + ) + ) + # TODO(jansel): returning None here is wrong, it should be + # RemovableHandle, but we need some extra work to support + # this properly. + return None + + from .builder import wrap_fx_proxy + + self_proxy = self.as_proxy() + self_proxy.node.meta["has_backward_hook"] = True + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + _register_hook_trampoline, + (self_proxy, bw_state_proxy), + {}, + ), + ) + + handle_variable = variables.RemovableHandleVariable( + mutation_type=variables.base.ValueMutationNew(), + ) + tx.output.side_effects.register_hook(self, hook, handle_variable, name) + return handle_variable + + def method_requires_grad_(self, requires_grad=True): + if requires_grad is not True: + requires_grad = requires_grad.as_python_constant() + + if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: + unimplemented( + gb_type="Unsupported Tensor.requires_grad_() call", + context=f"call_method {self} requires_grad_", + explanation="Dynamo does not support changes to a Tensor's " + "`requires_grad` through calling `requires_grad_()`.", + hints=[], + ) + else: + return self + + def method_new(self, *args, **kwargs): + # Convert x.new(torch.Size) into x.new_empty(torch.Size), + # as Tensor.new acts differently with a Size input versus a tuple input. + if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( + len(args) >= 1 + and all( + a.is_python_constant() and isinstance(a.as_python_constant(), int) + for a in args + ) + ): + from ..symbolic_convert import InstructionTranslator + + return self.call_method( + InstructionTranslator.current_tx(), "new_empty", args, kwargs + ) + + def method_untyped_storage(self): + return UntypedStorageVariable( + self, self.as_proxy().node.meta["example_value"].untyped_storage() + ) + + def set_name_hint(self, name: str): + if not self._is_name_set: + self.proxy.node._rename(name) + self._is_name_set = True + + def is_python_hashable(self): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return self.as_proxy().node.meta["example_value"] is not None + + def get_python_hash(self): + return hash(self.as_proxy().node.meta["example_value"]) + + def is_python_equal(self, other): + a = self.as_proxy().node.meta["example_value"] + b = other.as_proxy().node.meta["example_value"] + return a is b + + +class SymNodeVariable(VariableTracker): + """ + Represents a symbolic scalar, either int, float or bool. This is most commonly used to + handle symbolic size computation, e.g., tensor.size(0), but it is also used to + handle logic like float_tensor.item() or unspecialized float inputs. + """ + + _nonvar_fields = { + "proxy", + "sym_num", + *VariableTracker._nonvar_fields, + } + + def debug_repr(self): + return repr(self.sym_num) + + @classmethod + def create(cls, tx, proxy, sym_num=None, **options): + if sym_num is None: + sym_num = get_fake_value(proxy.node, tx) + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == sym_num + set_example_value(proxy.node, sym_num) + + if isinstance(sym_num, (sympy.Integer, int, bool)): + sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num + return ConstantVariable.create(sym_num) + + out = SymNodeVariable(proxy, sym_num, **options) + if proxy.node.op != "placeholder": + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out + + def __init__(self, proxy, sym_num, **kwargs) -> None: + super().__init__(**kwargs) + self.proxy = proxy + # TODO: Should we allow non SymTypes here? Today it is allowed + self.sym_num = sym_num + self._tensor_var = None + + def python_type(self): + if isinstance(self.sym_num, SymTypes): + return self.sym_num.node.pytype + else: + return type(self.sym_num) + + def is_symnode_like(self) -> bool: + return True + + def as_proxy(self): + return self.proxy + + def as_tensor(self, tx, dtype): + if self._tensor_var is None: + self._tensor_var = VariableTracker.build( + tx, torch.scalar_tensor + ).call_function(tx, [self], {"dtype": VariableTracker.build(tx, dtype)}) + return self._tensor_var + + def evaluate_expr(self, output_graph=None): + try: + return guard_scalar(self.sym_num) + except GuardOnDataDependentSymNode as e: + if torch.fx.experimental._config.no_data_dependent_graph_break: + raise + + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + # Essentially convert the SymNode to a constant variable whenever its + # searched for a dict key. + return hash(self.evaluate_expr()) + + def is_python_equal(self, other): + if isinstance(other, SymNodeVariable): + return self.evaluate_expr() == other.evaluate_expr() + # could be constant variable as well + return self.evaluate_expr() == other.as_python_constant() + + +class NumpyNdarrayVariable(TensorVariable): + """ + Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray. + Use this for Tensor.numpy() call. + """ + + @staticmethod + def create(tx: "InstructionTranslator", proxy, **options): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=tx, + proxy=proxy, + **options, + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + # NB: This INTENTIONALLY does not call super(), because there is + # no intrinsic reason ndarray properties are related to Tensor + # properties. The inheritance here is for implementation sharing. + + from ..utils import numpy_attr_wrapper + from .builder import wrap_fx_proxy + + result = None + + example_value = self.as_proxy().node.meta["example_value"] + example_ndarray = tnp.ndarray(example_value) + + def insert_into_graph(): + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {} + ), + ) + + if name in ["T", "real", "imag"]: + proxy = tx.output.create_proxy( + "call_function", + numpy_attr_wrapper, + (self.as_proxy(), name), + {}, + ) + result = NumpyNdarrayVariable.create(tx, proxy) + + # These are awkward to implement. The standard playbook for torch._numpy + # interop is to trace a call into the torch._numpy wrapper which works for + # Tensor operations. However, we don't want to do this for calls + # that don't return Tensors, because in those cases we may not want + # to trace the attribute access into the graph at all (it is sort + # of harmless to do so, because AOTAutograd will eliminate them, + # but it's best not to trace them in to begin with.) But in any + # case, tracing these into the graph is like trying to fit a square + # peg into a round hole; best not to do it. So instead we + # painstakingly implement these by hand + # + # NB: only ALWAYS specialized attributes can go here; notably, + # size/shape not allowed! + elif name in ("ndim", "itemsize"): + return ConstantVariable.create(getattr(example_ndarray, name)) + elif name in ("shape", "stride"): + if not has_free_symbols(r := getattr(example_ndarray, name)): + return ConstantVariable.create(tuple(int(r) for r in r)) + return insert_into_graph() + elif name == "size": + if not has_free_symbols(r := example_ndarray.size): + return ConstantVariable.create(int(r)) + return insert_into_graph() + elif name in ["base", "flags", "dtype"]: + unimplemented( + gb_type="Unsupported ndarray attribute access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + elif name == "__version__": + unimplemented( + gb_type="Unsupported ndarray.__version__ access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + if result is None: + raise NotImplementedError + return result + + @staticmethod + def patch_args(name, args, kwargs): + if name == "clip": + kwargs_rename = {"a_min": "min", "a_max": "max"} + kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()} + return args, kwargs + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..exc import unimplemented + from ..utils import numpy_method_wrapper + + args, kwargs = self.patch_args(name, args, kwargs) + + if name == "astype": + from .builtin import BuiltinVariable + + dtype_arg = None + if "dtype" in kwargs: + dtype_arg = kwargs["dtype"] + elif len(args) > 0: + dtype_arg = args[0] + is_object_str = dtype_arg is not None and dtype_arg.is_constant_match("O") + is_object_type = ( + isinstance(dtype_arg, BuiltinVariable) and dtype_arg.fn is object + ) + if is_object_str or is_object_type: + unimplemented( + gb_type="ndarray.astype(object)", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=( + "`ndarray.astype('O')` or `ndarray.astype(object)` is not supported " + "by torch.compile, as there is no equivalent to object type in torch.Tensor. " + "This will be executed eagerly." + ), + hints=[*graph_break_hints.FUNDAMENTAL], + ) + if name in ["__len__", "size", "tolist", "__iter__"]: + # delegate back to TensorVariable + return super().call_method(tx, name, args, kwargs) + if name in ("tostring", "tobytes", "__delattr__"): + unimplemented( + gb_type="Unsupported ndarray method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"`ndarray.{name}()` is not modelled in `torch._numpy`.", + hints=[], + ) + proxy = tx.output.create_proxy( + "call_function", + numpy_method_wrapper(name), + *proxy_args_kwargs([self] + list(args), kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def python_type(self): + return np.ndarray + + +class UnspecializedPythonVariable(TensorVariable): + """ + This is a 1-element tensor represents unspecialized python float/int. + """ + + _nonvar_fields = { + "raw_value", + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__( + self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs + ) -> None: + super().__init__(proxy, **kwargs) + self.raw_value = raw_value + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True): + # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance. + return UnspecializedPythonVariable( + **dict(tensor_variable.__dict__), + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + + +class FakeItemVariable(TensorVariable): + """An unspecialized python variable which prevents access to the underlying raw value. + This is needed if item is called on a FakeTensor.""" + + _nonvar_fields = { + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None: + need_unwrap = kwargs.pop("need_unwrap", False) + super().__init__(proxy, **kwargs) + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable): + return FakeItemVariable(**dict(tensor_variable.__dict__)) + + +class TensorSubclassVariable(UserDefinedClassVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable + + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if len(args) == 1 and args[0].is_tensor() and len(kwargs) == 0: + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, data, self.value, self.source + ) + else: + unimplemented( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs + ) + + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) + + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + + def as_python_constant(self): + return self.value + + +class UntypedStorageVariable(VariableTracker): + _nonvar_fields = { + "example_value", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + from_tensor: TensorVariable, + example_value: torch.UntypedStorage, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + # Example_value will always have device="meta" + self.example_value = example_value + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + result = self.example_value.size() + if not has_free_symbols(result): + # avoid creating a node in the graph + return ConstantVariable.create(int(result)) + else: + from ..external_utils import untyped_storage_size + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + untyped_storage_size, + (self.from_tensor.as_proxy(),), + {}, + ), + ) + if name == "resize_" and len(args) == 1: + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + tx.output.create_proxy( + "call_function", + torch.ops.inductor.resize_storage_bytes_, + (self.from_tensor.as_proxy(), args[0].as_proxy()), + {}, + ) + return self + + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("untyped_storage") + codegen.call_method(0) + + +class DataPtrVariable(VariableTracker): + def __init__( + self, + from_tensor: TensorVariable, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("data_ptr") + codegen.call_method(0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3c3afc551b8f0fde5527bf4adcae3689bb3b9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py @@ -0,0 +1,2183 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +""" +This module implements variable tracking for torch functions and operations during Dynamo tracing. + +It provides classes to handle different types of torch operations: + +TorchInGraphFunctionVariable: Handles torch.* functions that should be captured in the FX graph. +Provides special handling for constant folding, tensor methods, and torch function overrides. +Manages complex cases like out= variants and parameter construction. + +TorchCtxManagerClassVariable: Handles torch context managers like torch.no_grad(), autocast, etc. +Provides implementations for entering/exiting these contexts during tracing. + +DispatchKeySetVariable: Represents torch.DispatchKeySet for managing dispatch keys and +device-specific operations during tracing. + +The module includes special handling for: +- Constant folding of pure functions +- Tensor method calls +- torch.nn.Parameter construction +- __torch_function__ overrides +- Context manager state tracking +- Device and dtype management + +This is a core part of Dynamo's tracing system, translating torch operations into +traceable graph nodes while preserving correct semantics and handling edge cases. +""" + +import functools +import inspect +import logging +import math +import re +from collections.abc import Callable, Sequence +from typing import Any, Optional, TYPE_CHECKING + +import torch._C +import torch._refs +import torch.fx +import torch.nn +from torch._guards import TracingContext +from torch._logging import warning_once +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import config, graph_break_hints, polyfills, variables +from ..codegen import PyCodegen +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) +from ..device_interface import get_registered_device_interfaces +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + CallFunctionNoArgsSource, + SyntheticLocalSource, + TorchSource, +) +from ..utils import ( + check_unspec_or_constant_args, + guard_if_dyn, + has_torch_function, + hashable, + is_wrapper_or_member_descriptor, + product, + proxy_args_kwargs, + unwrap_if_wrapper, +) +from .base import raise_type_error_exc, typestr, VariableTracker +from .ctx_manager import ( + AutocastModeVariable, + ProfilerContextVariable, + TorchFunctionDisableVariable, +) +from .dicts import ConstDictVariable +from .distributed import DistributedVariable, ProcessGroupVariable +from .functions import bind_args_cached, NestedUserFunctionVariable +from .lists import ListVariable, TupleVariable +from .torch_function import ( + can_dispatch_torch_function, + dispatch_torch_function, + TensorWithTFOverrideVariable, + TorchFunctionModeStackVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) + +supported_ctx_manager_classes = dict.fromkeys( + [ + torch.profiler.profiler.profile, + torch.autograd.forward_ad._set_fwd_grad_enabled, + torch.autograd.forward_ad.dual_level, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + torch._C.DisableTorchFunctionSubclass, + torch._C.DisableTorchFunction, + torch._functorch.vmap.vmap_increment_nesting, + torch._functorch.eager_transforms.grad_increment_nesting, + torch._functorch.eager_transforms.jvp_increment_nesting, + torch._functorch.eager_transforms.enable_inplace_requires_grad, + torch.amp.autocast_mode.autocast, + torch.autograd.grad_mode.enable_grad, + torch.autograd.grad_mode.inference_mode, + torch.autograd.grad_mode.no_grad, + torch.autograd.grad_mode.set_grad_enabled, + torch.autograd.graph.disable_saved_tensors_hooks, + torch.cpu.amp.autocast_mode.autocast, + torch.cuda.amp.autocast_mode.autocast, + torch.fx.traceback.annotate, + torch.fx.traceback.annotate.__wrapped__, # type: ignore[attr-defined] + # We'll let Dynamo inline into the contextlib part of these context + # manager instances, all the way till it invokes the wrapped function + # itself (at which point we wrap it back to special context manager + # VTs). + # + # This allows us to support calling functions decorated with these + # context managers, without much extra effort or code dup. + torch.nn.attention.sdpa_kernel.__wrapped__, # type: ignore[attr-defined] + ] +) + + +REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys( + [ + torch._shape_as_tensor, + ] +) + +constant_fold_functions_need_guards = [ + torch.accelerator.current_device_index, + torch.accelerator.current_accelerator, + torch.cuda.current_device, + torch.cuda.is_initialized, + torch.xpu.current_device, + torch.xpu.is_initialized, +] + +constant_fold_functions = [ + torch._assert, + torch._utils._get_device_index, + torch._C._get_cublas_allow_tf32, + torch._C._is_any_autocast_enabled, + torch.accelerator.is_available, + torch.cuda.get_device_properties, + torch.cuda.is_available, + torch.distributed.is_available, + torch.get_autocast_dtype, + torch.get_autocast_gpu_dtype, + torch.get_default_dtype, + torch.is_autocast_cache_enabled, + torch.is_autocast_cpu_enabled, + torch.is_autocast_enabled, + torch.is_complex, + torch.is_floating_point, + torch.nn.functional._Reduction.get_enum, # type: ignore[attr-defined] + torch.promote_types, + torch._C._get_privateuse1_backend_name, + torch.autograd._is_checkpoint_valid, + torch.xpu.get_device_properties, + torch.xpu.is_available, +] + constant_fold_functions_need_guards +if torch.distributed.is_available(): + constant_fold_functions.extend( + [ + torch.distributed.is_initialized, + torch.distributed.get_rank, + torch.distributed.get_world_size, + ] + ) +# Convert to dict for O(1) access times +constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need_guards) +constant_fold_functions = dict.fromkeys(constant_fold_functions) + + +@functools.cache +def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: + # Defined as a function to avoid circular import like torch.onnx + return { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.fx._symbolic_trace.is_fx_symbolic_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, + torch.compiler.is_exporting: True, + torch._dynamo.eval_frame._is_in_optimized_module: True, + # Look into https://github.com/pytorch/pytorch/pull/164721 why this is + # turned to True for Dynamo. + torch.nn.modules.activation._is_make_fx_tracing: True, + } + + +bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) + +dispatch_key_set_functions = { + torch._C._dispatch_keys, + torch._C._dispatch_tls_local_include_set, + torch._C._dispatch_tls_local_exclude_set, +} + + +@functools.cache +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + funcs = set(chain.from_iterable(get_overridable_functions_().values())) + more: set[Callable[..., Any]] = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs + + +class BaseTorchVariable(VariableTracker): + """common base for all torch.* functions, classes, modules and other things""" + + @classmethod + def create_with_source(cls, value, source): + if inspect.isclass(value): + install_guard(source.make_guard(GuardBuilder.CLASS_MATCH)) + elif inspect.ismodule(value): + install_guard(source.make_guard(GuardBuilder.MODULE_MATCH)) + elif inspect.isfunction(value): + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + elif inspect.isbuiltin(value) or isinstance( + value, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) + ): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + elif is_wrapper_or_member_descriptor(value) or isinstance( + value, torch._dynamo.compiled_autograd.Op + ): + # Dont need to guard on wrappers + pass + else: + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls(value, source=source) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def reconstruct(self, codegen: "PyCodegen"): + try: + name = f"{self.value.__module__}.{self.value.__name__}" + except Exception: + name = f"torch_obj_{id(self.value)}" + unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) + codegen.extend_output( + codegen.setup_globally_cached(unique_var_name, self.value) + ) + + def as_proxy(self): + return self.value + + def as_python_constant(self): + return self.value + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def can_constant_fold_through(self): + if self.value in constant_fold_functions: + return True + + if ( + self.value is torch.autograd._profiler_enabled + and config.constant_fold_autograd_profiler_enabled + ): + # The relevant flag is enabled only for export. One might wonder + # why? + # + # Actually we would like to not graph break even in the case of + # Dynamo. But there is a weird-unsolved bug with Kineto + Dynamo + # when there are distributed jobs that lead to NCCL timeouts. This + # bug is a rare edege case, but we have not been able to root cause + # it yet. See https://www.internalfb.com/sevmanager/view/560336 for + # more details. + # + # So is this safe for export? Yes, for export, we do not anticipate + # JIT tracing in distributed job training, and the weird edge-case + # interaction with Kineto is not a valid usecase. So, this is ok. + return True + + return getattr(self.value, "__module__", None) == "math" + + +class TorchCtxManagerClassVariable(BaseTorchVariable): + """Points to a context manager class in torch.* that dynamo has implementations""" + + def __repr__(self) -> str: + return f"TorchCtxManagerClassVariable({self.value})" + + @staticmethod + def is_matching_cls(value): + # Unwrap if it's a functools.lru_cache wrapper + value = unwrap_if_wrapper(value) + # We can't do isinstance(value, type) check because some ctx managers + # are implemented as a function decorated by contextlib.contextmanager, + # E.g., torch._functorch.vmap.vmap_increment_nesting. + return ( + # Context manager type or function with @contextmanager is callable + callable(value) + and ( + hashable(value) # accesses value.__hash__() + and value in supported_ctx_manager_classes + ) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + FSDPParamGroupUseTrainingStateVariable, + FxTracebackAnnotateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + StreamVariable, + VmapIncrementNestingCtxManagerVariable, + ) + + if self.value is torch.no_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, False) + return ctx.call_function(tx, args, kwargs) + else: + return GradModeVariable.create(tx, False) + elif self.value is torch.enable_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, True) + return ctx.call_function(tx, args, kwargs) + return GradModeVariable.create(tx, True) + elif self.value is torch.set_grad_enabled and len(args) == 1: + return GradModeVariable.create( + tx, args[0].as_python_constant(), initialized=True + ) + elif self.value is torch.inference_mode: + assert len(args) <= 1 and len(kwargs) == 0 + inf_mode = args[0].as_python_constant() if len(args) == 1 else True + return InferenceModeVariable.create(tx, inf_mode) + elif self.value in ( + torch.fx.traceback.annotate, + torch.fx.traceback.annotate.__wrapped__, # type: ignore[attr-defined] + ): + assert len(args) <= 1 and len(kwargs) == 0 + return FxTracebackAnnotateVariable( + args[0].as_python_constant(), source=self.source + ) + elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream): + from torch._dynamo.variables.builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + elif self.value in ( + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ): + # pyrefly: ignore [bad-argument-type] + return AutocastModeVariable.create(self.value, args, kwargs) + elif self.value in ( + # NOTE any class added here must align with the semantic + # requirements of `ProfilerContextVariable`. + torch.profiler.profile, + torch.profiler.record_function, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + ): + warning_once(log, "Profiler function %s will be ignored", self.value) + return ProfilerContextVariable() + elif ( + self.value is torch._C.DisableTorchFunctionSubclass + or self.value is torch._C.DisableTorchFunction + ): + assert not (args or kwargs) + return TorchFunctionDisableVariable.create( + tx, only_subclass=self.value is torch._C.DisableTorchFunctionSubclass + ) + elif self.value is torch._functorch.vmap.vmap_increment_nesting: + assert len(args) == 2 + return VmapIncrementNestingCtxManagerVariable.create( + tx, + args, + ) + elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting: + assert len(args) == 0 + return JvpIncrementNestingCtxManagerVariable.create(tx) + elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled: + assert len(args) == 1 + return SetFwdGradEnabledContextManager.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.forward_ad.dual_level: + assert len(args) == 0 + return DualLevelContextManager.create(tx) + elif self.value is torch._functorch.eager_transforms.grad_increment_nesting: + assert len(args) == 0 + return GradIncrementNestingCtxManagerVariable.create(tx) + elif ( + self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad + ): + assert len(args) == 1 + return GradInplaceRequiresGradCtxManagerVariable.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.graph.disable_saved_tensors_hooks: + assert len(args) == 1 + return DisabledSavedTensorsHooksVariable.create( + tx, args[0].as_python_constant() + ) + elif ( + _fsdp_param_group is not None + and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + assert len(args) == 2 + return FSDPParamGroupUseTrainingStateVariable.create( + tx, args[0], args[1].as_python_constant() + ) + elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] + name_to_arg_map = bind_args_cached( + # pyrefly: ignore[bad-argument-type] + self.value, + tx, + self.source, + args, + kwargs, + ) + backends = name_to_arg_map["backends"].as_python_constant() + set_priority = name_to_arg_map["set_priority"].as_python_constant() + return SDPAKernelVariable.create(tx, backends, set_priority) + + return super().call_function(tx, args, kwargs) + + +class TorchInGraphFunctionVariable(BaseTorchVariable): + """Points to a torch function/method that should be put in FX graph""" + + def __init__(self, value, nonstrict_traceable=None, **kwargs) -> None: + super().__init__(value, **kwargs) + from ..trace_rules import is_nonstrict_trace_callable + + if nonstrict_traceable is None: + nonstrict_traceable = is_nonstrict_trace_callable(value) + self.nonstrict_traceable = nonstrict_traceable + + def __repr__(self) -> str: + return f"TorchInGraphFunctionVariable({self.value}, nonstrict_traceable={self.nonstrict_traceable})" + + def get_function(self): + return self.value + + @staticmethod + @functools.cache + def _get_handlers(): + """Build a dict from function -> method to handle it so that we are O(1) + in terms of the number of function with special handling.""" + handlers = {} + + def register(*fns): + def _register(handler): + for fn in fns: + assert fn not in handlers, fn + handlers[fn] = handler + return handler + + assert callable(fns[0]) + return _register + + from torch.backends.cuda import SDPAParams + + from . import ( + ConstantVariable, + DeterministicAlgorithmsVariable, + GradModeVariable, + StreamContextVariable, + SymNodeVariable, + TensorVariable, + UserDefinedObjectVariable, + ) + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + @register(*tracing_state_functions()) + def handle_tracing_state_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + # See: https://github.com/pytorch/pytorch/issues/110765 + if self.value in ( + torch._utils.is_compiling, + torch._dynamo.external_utils.is_compiling, + torch.compiler.is_compiling, + torch.compiler.is_dynamo_compiling, + torch.compiler.is_exporting, + torch._dynamo.eval_frame._is_in_optimized_module, + ): + tx.mark_inconsistent_side_effects() + return ConstantVariable.create(tracing_state_functions()[self.value]) + + @register(*dispatch_key_set_functions) + def handle_dispatch_key_set_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not kwargs + if self.value is torch._C._dispatch_keys: + assert len(args) == 1 + assert args[0].is_tensor() + example_value = args[0].proxy.node.meta["example_value"] + dks = self.value(example_value) + # Remove Python and PythonTLSSnapshot from the dispatch key set, + # as they originate from FakeTensor propagation. + # This should only be done if the example_value is a FakeTensor. + # However, if tensor subclasses are present, + # it is reasonable for Python to remain in the dispatch key set. + if isinstance(example_value, torch._subclasses.FakeTensor): + dks = ( + dks + - torch._C.DispatchKeySet(torch._C.DispatchKey.Python) + - torch._C.DispatchKeySet( + torch._C.DispatchKey.PythonTLSSnapshot + ) + ) + return DispatchKeySetVariable.create(dks) + else: + assert not args + return DispatchKeySetVariable.create(self.value()) + + @register(torch.overrides.get_default_nowrap_functions.__wrapped__) + def handle_get_default_nowrap_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # [Note: __torch_function__] we return empty here because we restrict + # the set of functions that we trace __torch_function__ on to + # functions outside of the actual set. Implementing this properly will require implementing + # some variable types to track and compare tensor getset descriptors + return VariableTracker.build( + tx, torch.overrides.get_default_nowrap_functions() + ) + + @register(torch.ops.inductor.accumulate_grad_.default) + def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs): + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs + ) + + @register(math.radians) + def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs): + if not check_unspec_or_constant_args(args, kwargs): + # Use polyfill to convert math.radians(x) into math.pi * x / 180.0 + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.radians), args, kwargs + ) + + if hasattr(math, "fma"): # Python 3.13+ + + @register(math.fma) + def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) != 3 or kwargs: + return None + + if all(arg.is_tensor() for arg in args): + x, y, z = args + addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul) + return addcmul_fn.call_function(tx, [z, x, y], {}) + + # Use math.fma if constants + return None + + @register(torch.is_inference_mode_enabled) + def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"): + unimplemented( + gb_type="Encountered torch.is_inference_mode_enabled during tracing", + context="", + explanation="torch.is_inference_mode_enabled() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + + @register(torch.is_tensor, torch.overrides.is_tensor_like) + def handle_is_tensor(self, tx: "InstructionTranslator", arg): + if arg.is_tensor() or ( + self.value is torch.overrides.is_tensor_like + and isinstance(arg, UserDefinedObjectVariable) + and hasattr(arg.value, "__torch_function__") + ): + return ConstantVariable.create(True) + else: + return ConstantVariable.create(False) + + @register( + torch.is_floating_point, + torch.is_complex, + ) + def handle_is_floating_point(self, tx: "InstructionTranslator", input): + input_arg = input + if input_arg.is_tensor() and input_arg.dtype is not None: + if self.value is torch.is_floating_point: + return ConstantVariable.create(input_arg.dtype.is_floating_point) + elif self.value is torch.is_complex: + return ConstantVariable.create(input_arg.dtype.is_complex) + else: + raise AssertionError(f"calling {self.value}") + + @register(torch.numel) + def handle_numel(self, tx: "InstructionTranslator", input): + if input.is_tensor() and input.valid_size(): + return ConstantVariable.create(product(input.size)) + elif input.is_tensor(): + # Workaround dynamic shapes issue + return input.call_method(tx, "numel", [], {}) + + @register(torch.compile) + def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 1: + # torch.compile is a no-op in dynamo + return args[0] + + unimplemented( + gb_type="torch.compile call with > 1 args", + context=f"args={args}, kwargs={kwargs}", + explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.", + hints=[ + "Remove the torch.compile call or its additional args.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) + def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): + assert input.is_tensor() + return input.call_method(tx, "size", [], {}) + + @register( + torch.nn.modules.utils._single, + torch.nn.modules.utils._pair, + torch.nn.modules.utils._triple, + torch.nn.modules.utils._quadruple, + torch.nn.modules.utils._ntuple, + ) + def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs): + return self._call_ntuple(tx, args, kwargs) + + @register(torch.is_grad_enabled) + def handle_is_grad_enabled(self, tx): + install_guard(GradModeVariable._guards_singleton) + return ConstantVariable.create(torch.is_grad_enabled()) + + @register(torch.use_deterministic_algorithms) + def handle_use_deterministic_algorithms( + self, tx: "InstructionTranslator", mode, warn_only=False + ): + # pyrefly: ignore [missing-attribute] + if warn_only and warn_only.as_python_constant(): + unimplemented( + gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", + context=f"mode={mode}, warn_only={warn_only}", + explanation="Dynamo does not support this.", + hints=[ + "Remove param warn_only in function call torch.use_deterministic_algorithms.", + *graph_break_hints.SUPPORTABLE, + ], + ) + return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) + + @register(torch.are_deterministic_algorithms_enabled) + def handle_are_deterministic_algorithms_enabled(self, tx): + install_guard(DeterministicAlgorithmsVariable._guards_singleton) + return ConstantVariable.create(torch.are_deterministic_algorithms_enabled()) + + @register(torch._C._is_torch_function_enabled) + def handle_is_torch_function_enabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + # see comment on SymbolicTorchFunctionState class as to why + # this is not a bug + return ConstantVariable.create( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + + @register(torch._C._is_torch_function_all_disabled) + def handle_is_torch_function_all_disabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + return ConstantVariable.create( + not tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + @register( + torch.overrides.has_torch_function, + torch.overrides.has_torch_function_variadic, + torch.overrides.has_torch_function_unary, + ) + def handle_has_torch_function(self, tx: "InstructionTranslator", *args): + elems = ( + args[0].unpack_var_sequence(tx) + if len(args) == 1 and isinstance(args[0], TupleVariable) + else args + ) + return ConstantVariable.create( + any(has_torch_function(x) for x in elems), + ) + + @register( + *dict.fromkeys( # remove duplicates + device_interface.stream + for _, device_interface in get_registered_device_interfaces() + ) + ) + def handle_device_interface_stream(self, tx: "InstructionTranslator", stream): + return StreamContextVariable.create(tx, stream) + + @register(torch.from_numpy) + def handle_from_numpy(self, tx: "InstructionTranslator", *args): + if not config.trace_numpy: + unimplemented( + gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", + context=f"trace_numpy={config.trace_numpy}", + explanation=( + "Attempted to call `torch.from_numpy` with config " + "`torch._dynamo.config.trace_numpy` set to `False`." + ), + hints=[ + "Change `torch._dynamo.config.trace_numpy` to `True`.", + ], + ) + if not np: + unimplemented( + gb_type="`torch.from_numpy` with NumPy unavailable", + context="", + explanation="Attempted to call `torch.numpy` but NumPy could not be imported.", + hints=[ + "Check NumPy version and installation in your environment.", + *graph_break_hints.USER_ERROR, + ], + ) + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.as_tensor, + *proxy_args_kwargs(args, {}), + ), + example_value=None, + ) + + @register(torch.jit.annotate) + def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value): + return the_value + + @register(torch.backends.cudnn.is_acceptable) + def handle_cudnn_is_acceptable( + self, tx: "InstructionTranslator", tensor, *extra + ): + # is_acceptable(tensor) returns true if + # (a) tensor dtype/device are supported by cudnn + # (b) cudnn is available + # (c) some initialization has completed + # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version) + assert not extra, "Expect 1 input to cudnn.is_acceptable" + assert tensor.is_tensor(), ( + "Expect input to cudnn.is_acceptable to be a tensor" + ) + tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device) + return ConstantVariable.create( + torch.backends.cudnn.is_acceptable(tensor_inp) + ) + + @register(torch.utils.hooks.BackwardHook) + def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs): + return variables.BackwardHookVariable.create(tx, *args, **kwargs) + + @register(torch.nn.Parameter) + def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs): + return self.call_nn_parameter(tx, *args, **kwargs) + + @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int) + def handle_sym_size(self_, tx, self, dim=None): + # we see this when retracing already traced code + if dim is not None: + return self.call_method(tx, "size", [dim], {}) + + @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int) + def handle_sym_stride(self_, tx, self, dim=None): + if dim is not None: + return self.call_method(tx, "stride", [dim], {}) + + @register(torch.addcdiv) + def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 3 and "value" in kwargs and len(kwargs) == 1: + # decompose addcdiv into constituent ops, prevents a graph break due to converting + # value to a scalar + result = TorchInGraphFunctionVariable(torch.div).call_function( + tx, [*args[1:]], {} + ) + result = TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, kwargs["value"]], {} + ) + return TorchInGraphFunctionVariable(torch.add).call_function( + tx, [args[0], result], {} + ) + + @register(torch.full) + def handle_full(self, tx, size, fill_value, **kwargs): + if fill_value.is_tensor(): + # Decompose: create empty tensor and fill it + # This avoids the scalar extraction at compile time + empty_result = TorchInGraphFunctionVariable(torch.empty).call_function( + tx, [size], kwargs + ) + # Call fill_ method on the empty tensor + return empty_result.call_method(tx, "fill_", [fill_value], {}) + + @register(torch._foreach_lerp_) + def handle_inplace_foreach_lerp_scalar( + _, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_lerp_inplace), + args, + kwargs, + ) + + @register(torch._foreach_pow) + def handle_foreach_pow_scalar(_, tx: "InstructionTranslator", *args, **kwargs): + # In eager it's more performant to call item() from within the C op implementation + # in compile, it's more performant to not graph break. + if len(args) == 2 and args[0].is_tensor() and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_pow_scalar), + args, + kwargs, + ) + + @register(torch._assert) + def handle_assert(self, tx: "InstructionTranslator", condition, message): + if (condition.is_python_constant() and condition.as_python_constant()) or ( + isinstance(condition, variables.SymNodeVariable) + and condition.evaluate_expr() + ): + return ConstantVariable(None) + + @register(SDPAParams) + def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): + return wrap_fx_proxy( + tx, + proxy=tx.output.create_proxy( + "call_function", + torch._C._SDPAParams, + *proxy_args_kwargs(args, kwargs), + ), + param_vars=args, + ) + + if DistributedVariable.is_available(): + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + from torch.distributed.tensor import DTensor + + @register( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ) + def handle_constant_processgroup_functions( + self, tx: "InstructionTranslator", *args + ): + # because the input is a "ProcessGroupVariable", we'll be guarding on its + # ID_MATCH based on how it was constructed. + + # We desugar it at trace-time into ranks by directly calling util + # bake the result into the trace + if len(args) == 1: + # group or group name + assert ( + isinstance(args[0], ProcessGroupVariable) + or args[0].is_python_constant() + ) + elif len(args) == 2: + # ranks + tag + assert ( + isinstance(args[0], ListVariable) + and args[1].is_python_constant() + ) + else: + raise AssertionError( + f"Invalid group value ({args}) for constant pg " + f"function {self.value}" + ) + args_as_value = [arg.as_python_constant() for arg in args] + invocation_result = self.value(*args_as_value) + + # Note - while we *could* cook up sources around invocations, like a FunctionSource + # the space of invoking functions in the middle of the guard chain is very iffy. As such, + # guard propagation via options is the best we can do. + return VariableTracker.build(tx, invocation_result) + + @register(DTensor.from_local) + def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + placements_vt = kwargs.get("placements") + + if placements_vt is None and len(args) >= 3: + placements_vt = args[2] + + if placements_vt is None: + placements_vt = ConstantVariable.create(None) + elif isinstance(placements_vt, variables.UserDefinedObjectVariable): + placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [placements_vt], {} + ) + + new_args = list(args) + if len(new_args) >= 3: + new_args[2] = placements_vt + elif kwargs.get("placements") is not None: + kwargs["placements"] = placements_vt + + args_as_value = [x.as_python_constant() for x in new_args[1:]] + kwargs_as_value = { + k: v.as_python_constant() + for k, v in kwargs.items() + if k not in ["shape", "stride"] + } + + kwargs_to_be_proxied = { + k: kwargs[k] for k in ["shape", "stride"] if k in kwargs + } + + def fn_with_prim_types(x, shape=None, stride=None): + return self.value( + x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride + ) + + # attach the same function name for better debugging + fn_with_prim_types.__name__ = "prim " + self.value.__name__ + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_with_prim_types, + *proxy_args_kwargs( + [args[0]], + kwargs_to_be_proxied, + ), + ), + ) + + @register(torch.nested.nested_tensor) + def handle_nested_tensor( + self, + tx: "InstructionTranslator", + tensor_list=None, + *args, + layout=None, + **kwargs, + ): + from .lists import BaseListVariable + + if layout and layout.is_constant_match(torch.strided): + unimplemented( + gb_type="Attempted to use strided NestedTensor", + context=f"layout={layout}", + explanation="Dynamo does not support this.", + hints=[ + "Change layout=torch.jagged.", + *graph_break_hints.SUPPORTABLE, + ], + ) + if not isinstance(tensor_list, BaseListVariable): + unimplemented( + gb_type="Attempted to use `nested_tensor` with non-list input", + context=f"tensor_list={tensor_list}", + explanation="Dynamo does not support this.", + hints=[ + "Change `nested_tensor` with list input.", + *graph_break_hints.USER_ERROR, + ], + ) + + @register(torch.nn.functional.one_hot) + def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) == 1 or ( + len(args) == 2 and args[1].is_constant_match(-1) + ): + unimplemented( + gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", + context=f"args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Explicitly set the `num_classes` param of the function call " + "`torch.nn.functional.one_hot` to something other than -1.", + ], + ) + + @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) + def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_size_oblivious( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_true) + def handle_guard_or_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_false) + def handle_guard_or_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.statically_known_false) + def handle_statically_known_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_false( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_scalar) + def guard_scalar(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif expr.is_python_constant(): + val = expr.as_python_constant() + else: + unimplemented( + gb_type="torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", + context=f"expr: {expr}", + explanation="Expected `expr` to be a symbolic variable or constant.", + hints=[], + ) + return variables.ConstantVariable.create( + # pyrefly: ignore [bad-argument-type, unbound-name] + torch.fx.experimental.symbolic_shapes.guard_scalar(val) + ) + + @register(torch.fx.experimental.symbolic_shapes.statically_known_true) + def handle_statically_known_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_true( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.sym_and) + def handle_sym_and(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_and( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.sym_or) + def handle_sym_or(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_or( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.has_static_value) + def handle_has_static_value(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif expr.is_python_constant(): + val = expr.as_python_constant() + else: + return + + return variables.ConstantVariable.create( + # pyrefly: ignore [bad-argument-type] + torch.fx.experimental.symbolic_shapes.has_static_value(val) + ) + + @register(torch._C._autograd._unsafe_set_version_counter) + def handle_unsafe_set_version_counter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + from ..tensor_version_op import _unsafe_set_version_counter + + return TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [*args], kwargs) + + @register(torch._C._functorch.peek_interpreter_stack) + def handle_functorch_peek_interpreter_stack( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Wrap C++ interpreter (torch._C._functorch.CInterpreter) as UserDefinedObjectVariable, + # but Python interpreter (torch._functorch.pyfunctorch.FuncTorchInterpreter) as FuncTorchInterpreterVariable. + return UserDefinedObjectVariable( + torch._C._functorch.peek_interpreter_stack() + ) + + @register(torch._functorch.pyfunctorch.coerce_cinterpreter) + def handle_functorch_pyfunctorch_coerce_cinterpreter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + cinterpreter = args[0].value + return FuncTorchInterpreterVariable( + torch._functorch.pyfunctorch.coerce_cinterpreter(cinterpreter) + ) + + @register(torch.tensor) + def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs): + def check_any_unspec(x): + # NB: This includes UnspecializedPythonVariable + if x.is_tensor() or isinstance(x, SymNodeVariable): + return True + elif isinstance(x, (ListVariable, TupleVariable)): + return any(check_any_unspec(y) for y in x.items) + # TODO: there maybe other recursive structures you need to + # check + else: + return False + + data_arg = None + if args: + data_arg = args[0] + elif "data" in kwargs: + data_arg = kwargs["data"] + + # NB: OK to pass torch.tensor(tensor), this will trace fine + if ( + data_arg is not None + and not data_arg.is_tensor() + and check_any_unspec(data_arg) + ): + # This is slower and less canonical, so only use it if we + # have to + return TorchInGraphFunctionVariable(torch._refs.tensor).call_function( + tx, [*args], kwargs + ) + + @register(torch._C._pop_torch_function_stack) + def handle_pop_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + if not tx.symbolic_torch_function_state.mode_stack: + unimplemented( + gb_type="Attempted to pop from empty torch function mode stack", + context="", + explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.", + hints=[ + "Do not pop from empty torch function mode stack.", + *graph_break_hints.USER_ERROR, + ], + ) + TorchFunctionModeStackVariable.register_mutation(tx) + return tx.symbolic_torch_function_state.pop_torch_function_mode() + + @register(torch._C._push_on_torch_function_stack) + def handle_push_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) != 1 or kwargs: + raise_type_error_exc( + tx, + f"push_torch_function takes exactly one argument ({len(args)} given)", + ) + TorchFunctionModeStackVariable.register_mutation(tx) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + return ConstantVariable.create(None) + + @register(torch._C._len_torch_function_stack) + def handle_len_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if args or kwargs: + raise_type_error_exc(tx, "len_torch_function_stack takes no arguments") + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) + + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) != 1 or kwargs: + raise_type_error_exc( + tx, + f"get_function_stack_at takes exactly one argument ({len(args)} given)", + ) + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + + @register(torch.get_device_module.__wrapped__) + def handle_get_device_module(self, tx, *args, **kwargs): + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): + unimplemented( + gb_type="improper torch.get_device_module arguments", + context=f"args={args}, kwargs={kwargs}", + explanation="torch.get_device_module accepts 1 optional argument `device`", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + try: + if kwargs: + device = kwargs["device"].as_python_constant() + elif args: + device = args[0].as_python_constant() + else: + device = None + module = torch.get_device_module(device) + except Exception as e: + unimplemented( + gb_type="bad device argument to torch.get_device_module", + context=f"args={args}, kwargs={kwargs}", + explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + hints=[*graph_break_hints.USER_ERROR], + from_exc=e, + ) + + # need to guard only on no-arg get_device_module + # pyrefly: ignore [unbound-name] + if device is None: + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + # assumes `module` is in the form `torch.xyz` + new_source = AttrSource( + TorchSource(), + # pyrefly: ignore [unbound-name] + module.__name__.rsplit(".", maxsplit=1)[-1], + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, module, new_source) + + @register(torch.accelerator.current_stream, torch.cuda.current_stream) + def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): + unimplemented( + gb_type="unsupported arguments to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="torch.accelerator.current_stream accepts one optional argument `device`", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + try: + if kwargs: + device = torch.device(kwargs["device"].as_python_constant()) + elif args: + device = torch.device(args[0].as_python_constant()) + else: + device = None + + return tx.symbolic_stream_state.cur_stream(device) + except Exception as e: + unimplemented( + gb_type="bad device argument to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + hints=[*graph_break_hints.USER_ERROR], + from_exc=e, + ) + + @register(torch.set_default_device) + def handle_set_default_device( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Today this is inserted in the graph, once TF mode + # handling is complete, we can trace the device context + # like any other TF mode and remove this special handling + # Insert the TF mode representing the device context at + # the bottom of the stack to match the eager semantics + # Running the graph will ensure that the DeviceContext mode is + # at the correct position in the stack + TorchFunctionModeStackVariable.register_mutation(tx) + if args[0].is_constant_none(): + TorchFunctionModeStackVariable.clear_default_device(tx) + else: + TorchFunctionModeStackVariable.register_device_context_insertion(tx) + + return ConstantVariable.create(None) + + @register(torch._check) + def handle_check(self, tx: "InstructionTranslator", *args, **kwargs): + predicate_vt = None + message_vt = None + + if args: + predicate_vt = args[0] + rest_args = args[1:] + else: + rest_args = () + + if predicate_vt is None and "cond" in kwargs: + predicate_vt = kwargs.pop("cond") + + if rest_args: + message_vt = rest_args[0] + elif "message" in kwargs: + message_vt = kwargs.pop("message") + + if predicate_vt is None: + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + + message_eager = None + message_graph_proxy = None + if message_vt is not None: + if ( + not isinstance(message_vt, NestedUserFunctionVariable) + or message_vt.has_closure() + ): + unimplemented( + gb_type="Can't extract message from torch._check()", + context=str(message_vt), + explanation=( + "The second argument of torch._check() must be a function" + "defined within the torch.compile region" + "that does not reference a non-local variable." + ), + hints=[ + "Make sure the message function is defined in the torch.compile region.", + "Remove any closure variables, e.g. " + "remove references to closure variable `x` in `lambda: f'{x} failed check'`", + *graph_break_hints.SUPPORTABLE, + ], + ) + message_eager = message_vt.get_function() + + message_graph_proxy = tx.output.register_static_attr_and_return_proxy( + "_check_message", message_eager + ) + + if predicate_vt.is_python_constant(): + self.value(predicate_vt.as_python_constant(), message_eager) + return ConstantVariable.create(None) + + predicate_proxy = predicate_vt.as_proxy() + + proxy_args: tuple[Any, ...] + if message_graph_proxy is None: + proxy_args = (predicate_proxy,) + else: + proxy_args = (predicate_proxy, message_graph_proxy) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + proxy_args, + {}, + ), + ) + + return handlers + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable, SymNodeVariable + from .builder import wrap_fx_proxy + + if self.nonstrict_traceable: + return self._call_nonstrict_traceable_function(tx, args, kwargs) + + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + # constant fold functions need to be guarded. + if self.value in constant_fold_functions_need_guards: + assert self.source is not None + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) + # constant fold + try: + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + except (OverflowError, TypeError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + if self.is_tensor_method(): + name = self.value.__name__ + # Guard against inplace view op on input tensor (not supported) + if args and args[0].is_tensor(): + tensor_var = args[0] + # Check if input tensor and inplace_view op specifically + if tensor_var.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view + in getattr(fn, fn.overloads()[0]).tags + ): + unimplemented( + gb_type="Inplace op on input tensor", + context="", + explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + "Ensure you do not modify input tensor in place.", + ], + ) + return self.call_tensor_method(tx, args, kwargs) + + special_handler = self._get_handlers().get(self.value) + if special_handler: + result = special_handler(self, tx, *args, **kwargs) + if result: + return result + + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + + all_ints_or_floats = all( + isinstance(x, SymNodeVariable) or x.is_python_constant() for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ +Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. +To support this behavior, we need to allow const-propping tensors that store symint data. +For now, dynamo will explicitly graph break when it encounters user code with this behavior. +""" + log.warning(msg) + unimplemented( + gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. " + "Dynamo does not support this." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + # TODO for each of the following check on `out=` or `requires_grad=` + # variant torch ops, the original function could come from a user + # defined `@allow_in_graph` function as well, which doesn't have the + # same semantics as the torch ops. + + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + saved_out_shapes = None + out_kwarg_vt = None + if "out" in kwargs: + out_kwarg_vt = kwargs["out"] + + # e.g., out=(t1, t2, ...) + if isinstance(out_kwarg_vt, (TupleVariable, ListVariable)): + saved_out_shapes = [] + for vt in out_kwarg_vt.items: + if vt.is_tensor(): + shape = vt.as_proxy().node.meta["example_value"].shape + else: + shape = None + saved_out_shapes.append(shape) + + # e.g., out=output_tensor + if out_kwarg_vt.is_tensor(): + saved_out_shapes = ( + out_kwarg_vt.as_proxy().node.meta["example_value"].shape + ) + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + # Handle e.g., `torch.ones(10, requires_grad=True)` + if ( + tensor_variable.is_tensor() + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + gb_type="Attempted to use tensor creation function with requires_grad=True", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Create the tensor outside the compiled region.", + "Do not set `requires_grad=True`.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Handle e.g., `torch.add(a, b, out=result)` + if saved_out_shapes is not None: + # out variants of torch operators like torch.sort and torch.sigmoid + # mutate the tensors in the out field. + # + # However, it's non-trivial to update all references of the old + # `TensorVariable` to the new one returned (`result_var`), so we + # take the conservative approach to graph break on size changes, and + # assume other cases can fall through soundly. + # + # Note that although these tensor variables would hold different + # proxies, the in-place mutation semantics is preserved in the FX + # graph, so we won't have correctness issues. + if isinstance(saved_out_shapes, list): + for out_tensor_vt, saved_out_shape in zip( + out_kwarg_vt.items, # type: ignore[union-attr] + saved_out_shapes, + ): + if saved_out_shape is None: + # This should be extremely rare, but it's kept for now + # until we invest in enforcing the `out=` kwarg for only + # torch methods. + continue + + assert out_tensor_vt.is_tensor() + fake_out = out_tensor_vt.proxy.node.meta["example_value"] + if saved_out_shape != fake_out.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented( + gb_type="Shape mismatch with out= list of tensor variants", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + gb_type="Attempted to call op with non-contiguous `out=` list of tensors", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + assert out_kwarg_vt is not None and out_kwarg_vt.is_tensor() + assert "example_value" in out_kwarg_vt.as_proxy().node.meta + fake_out = out_kwarg_vt.as_proxy().node.meta["example_value"] + if saved_out_shapes != fake_out.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented( + gb_type="Shape mismatch with out= tensor variant", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + if not torch._prims_common.is_contiguous_or_false(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + gb_type="Attempted to call op with non-contiguous `out=` tensor", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + return tensor_variable + + def _call_nonstrict_traceable_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + import torch._higher_order_ops.flat_apply as flat_apply + from torch._higher_order_ops.flat_apply import ( + func_to_graphable, + is_graphable_type, + ) + from torch._subclasses.fake_tensor import fake_tensor_tls + from torch.utils._pytree import tree_flatten + + from .base import AsPythonConstantNotImplementedError + from .builder import wrap_fx_proxy + + # 1. Convert `args, kwargs` into pytree-flattened proxy forms. + # + # Rather than reconstructing `args, kwargs` into python objects and + # then tree_flatten them, we just let Dynamo symbolically interpret + # `tree_flatten((args, kwargs))`. This saves us from having to + # worry about the reconstruction logic, side effects, and guards. + packed_input_vt = TupleVariable.build( + tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) + ) + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] + tx, [packed_input_vt], {} + ) + assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 + flat_args_vts, input_spec_vt = out_vt.items + assert isinstance(flat_args_vts, ListVariable) + + # Handle the case when the input contains a non-graphable type. + for flat_arg_vt in flat_args_vts.items: + arg_type = flat_arg_vt.python_type() + if not is_graphable_type(arg_type): + type_name = flat_arg_vt.python_type().__qualname__ + unimplemented( + gb_type="Invalid input type for nonstrict_trace-ed function", + context=f"Encountered input of type <{type_name}>.", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) " + "or pytree containers of those are allowed as inputs. The provided argument contains " + "an unsupported type." + ), + hints=[ + "Use one of the following to register the type with pytree:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + ], + ) + + # Since we checked with `is_graphable` above, `as_proxy` on the + # flat_arg VT should always work. + proxified_flat_args = [ + flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items + ] + + # The downstream `flat_apply` call requires the input spec; however, + # the spec not a graphable type, so we still have to reconstruct it + # into a python object, and store it as a constant attribute on the + # fx graph. + try: + input_spec = input_spec_vt.as_python_constant() + except AsPythonConstantNotImplementedError as e: + typ = e.vt.python_type() + type_name = typ.__qualname__ + import torch.utils._pytree as pytree + + if pytree.is_constant_class(typ): + unimplemented( + gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function with an input that contains an object " + f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object " + "was constructed _inside_ the `torch.compile` region. This is not supported." + ), + hints=[ + "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + else: + unimplemented( + gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered " + f"with a `pytree_flatten` that places an object of type <{type_name}> into the context." + ), + hints=[ + "Modifying the `pytree_flatten` to avoid placing the object into the context.", + f"Apply one of the following to <{type_name}>:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + + fn = self.value + + def patched_fn(*args, **kwargs): + # This enables reads to global/captured tensors, and we'll just + # treat them as constants in the graph. Note that after + # AOTDispatcher, this logic would disappear. + old_val = fake_tensor_tls.allow_non_fake_inputs_override + fake_tensor_tls.allow_non_fake_inputs_override = True + try: + res = fn(*args, **kwargs) + finally: # reset even when `fn` raises + fake_tensor_tls.allow_non_fake_inputs_override = old_val + return res + + # `flat_apply` wants a TreeSpec for the function input. + _, f_spec = func_to_graphable(patched_fn) + + # TreeSpec isn't graphable, so we register the function and input + # specs as attributes on the graph module. + f_spec_proxy = tx.output.register_static_attr_and_return_proxy( + f"{fn.__name__}_spec", f_spec + ) + input_spec_proxy = tx.output.register_static_attr_and_return_proxy( + fn.__name__ + "_input_spec", + # pyrefly: ignore [unbound-name] + input_spec, + ) + f_spec_proxy.node.type = type(f_spec) + # pyrefly: ignore [unbound-name] + input_spec_proxy.node.type = type(input_spec) + all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) + + # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate + # the call and wrap output into a VariableTracker. + proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) + try: + # TODO support more output types once `flat_apply` supports + # pytree-able output types. We can have Dynamo trace through an + # unflatten call (just like we traced through a flatten above) + # to rebuild the actual output VT. + out_vt = wrap_fx_proxy(tx, proxy) + except ( + # From `handle_traced_output`. + torch._dynamo.exc.Unsupported, + # From `flat_apply` assert on output type. + torch._dynamo.exc.TorchRuntimeError, + ): + unimplemented( + gb_type="Unsupported output type for nonstrict_trace-ed function", + context=f"Function: {fn.__name__}", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" + " are allowed as output. The result of this call contains an unsupported type." + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return out_vt + + def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): + """inline behavior of torch.nn.modules.utils._ntuple""" + if self.value is torch.nn.modules.utils._ntuple: + count = args[0].as_python_constant() + else: + count = self.value.__closure__[0].cell_contents + assert isinstance(count, int) + assert not kwargs + + def handle_ntuple(value): + if value.has_unpack_var_sequence(tx): + return variables.TupleVariable( + list(value.unpack_var_sequence(tx)), + ) + elif value.is_python_constant(): + # constant prop through it + return variables.ConstantVariable.create( + torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), + ) + else: + unimplemented( + gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", + context=f"value={value}", + explanation="Dynamo does not support this.", + hints=[ + "Change use of _ntuple with argument as constant or tensor.", + ], + ) + + if self.value is torch.nn.modules.utils._ntuple: + return variables.LambdaVariable(handle_ntuple) + else: + return handle_ntuple(args[0]) + + @classmethod + def call_nn_parameter(cls, tx, data=None, requires_grad=True): + """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented( + gb_type="Attempted to use `torch.nn.Parameter()` with export", + context="", + explanation="Dynamo does not support this.", + hints=[ + "Do not use `torch.nn.Parameter()` with export.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if isinstance(requires_grad, variables.VariableTracker): + try: + requires_grad = requires_grad.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`", + context=f"requires_grad={requires_grad}", + explanation="Dynamo does not support this.", + hints=[ + "Change `requires_grad` to be a bool.", + *graph_break_hints.USER_ERROR, + ], + ) + + if data is None or not data.is_tensor(): + unimplemented( + gb_type="`torch.nn.Parameter()` with unsupported data type", + context=f"data={data}", + explanation="Called `torch.nn.Parameter()` with non-Tensor argument.", + hints=[ + "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.", + *graph_break_hints.USER_ERROR, + ], + ) + + # this results in cleaner graphs, but only works for inputs + # pyrefly: ignore [missing-attribute] + if data.source: + return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + + if config.graph_break_on_nn_param_ctor: + # Need user to manually move since we cannot + unimplemented( + gb_type="Attempted to use `torch.nn.Parameter()` constructor with Dynamo", + context="", + explanation="Dynamo does not support this", + hints=[ + "Try to construct `torch.nn.Parameter()` outside the compiled region.", + "If this is not possible, turn `graph_break_on_nn_param_ctor` off", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # TODO[@lucaskabela]: Remove the behavior below since it is deprecated + if isinstance( + data, + TensorWithTFOverrideVariable, + # pyrefly: ignore [missing-attribute] + ) or is_traceable_wrapper_subclass_type(data.class_type): + unimplemented( + gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", + context=str(data), + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if not can_convert_to_tracable_parameter(): + unimplemented( + gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable", + context="", + explanation="convert_tracable_parameter is set to False.", + hints=[ + "Check usage of context manager: do_not_convert_to_tracable_parameter", + *graph_break_hints.DIFFICULT, + ], + ) + + try: + # pyrefly: ignore [missing-attribute] + shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) + # pyrefly: ignore [missing-attribute] + dtype = data.var_getattr(tx, "dtype").as_python_constant() + # pyrefly: ignore [missing-attribute] + device = data.var_getattr(tx, "device").as_python_constant() + except NotImplementedError as e: + unimplemented( + gb_type="`torch.nn.Parameter` with non-constant Tensor attributes", + context=f"data={data}", + explanation="Dynamo does not support this.", + hints=[ + "Ensure the Tensor argument's shape, dtype, and device are correct.", + *graph_break_hints.USER_ERROR, + ], + from_exc=e, + ) + + placeholder = tx.output.synthetic_graph_input( + new_parameter_placeholder, + # pyrefly: ignore [unbound-name] + [shape, dtype, device, requires_grad], + ) + # pyrefly: ignore [missing-attribute] + if data.requires_grad: + # pyrefly: ignore [missing-attribute] + data = data.call_method(tx, "detach", [], {}) + + from .builder import wrap_fx_proxy + + result = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + tracable_create_parameter, + # pyrefly: ignore [missing-attribute] + (data.as_proxy(), placeholder.as_proxy()), + {}, + ), + # In reconstruct() we should use the original parameter. The one + # returned by the graph will be an alias. + source=placeholder.source, + ) + assert result.is_tensor() + result.class_type = torch.nn.Parameter # type: ignore[union-attr] + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_parameter. It does not seem to use the right + # grad_enabled. Since this is parameter, we can just override the + # has_grad_fn field to False to workaround the issue. + result.has_grad_fn = False # type: ignore[union-attr] + + # TODO(jansel): if the new param falls out of scope, currently it won't get freed until + # the end of the graph. We should fix this. + return result + + @staticmethod + def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad): + # Alternate version if we have a .source + varname = tx.output.new_var() + + # construct the nn.Parameter before the graph save it to varname + assert tx.output.root_tx is not None + cg = PyCodegen(tx.output.root_tx) + cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter")) + cg(data.source) + cg(variables.ConstantVariable(requires_grad)) + cg.call_function(2, False) + cg.store(varname) + tx.output.pregraph_bytecode.extend(cg.get_instructions()) + + data_node = data.as_proxy().node + if data_node.op not in ("placeholder", "get_attr"): + unimplemented( + gb_type="Unexpected type of data placeholder op for parameter construction", + context=f"data_node.op={data_node.op}", + explanation="Data node op should be placeholder or get_attr.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + + # add the newly constructed nn.Parameter as a graph input + source = SyntheticLocalSource(varname) + example_value = torch.nn.Parameter( + tx.output.example_value_from_input_node(data.as_proxy().node), + requires_grad=requires_grad, + ) + result = VariableTracker.build(tx, example_value, source) + # Realize the VT because we will delete the guards on it in the next line. + result = result.realize() + # No need to guard on this since we already guarded on `data`. + # These guards would fail since varname doesn't exist until after the function starts + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + + def call_tensor_method(self, tx, args, kwargs): + return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs) + + def is_tensor_method(self): + from ..trace_rules import get_tensor_method + + return ( + inspect.ismethoddescriptor(self.get_function()) + and hasattr(self.get_function(), "__objclass__") + and self.get_function().__objclass__ == torch._C.TensorBase + ) or self.get_function() in get_tensor_method() + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class DispatchKeySetVariable(BaseTorchVariable): + """represents torch.DispatchKeySet""" + + @staticmethod + def create(value, **kwargs): + return DispatchKeySetVariable(value, **kwargs) + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.DISPATCH_KEY_SET_MATCH)) + return cls(value, source=source) + + def is_constant_fold_method(self, name): + return name == "has" + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if self.is_constant_fold_method(name) and check_unspec_or_constant_args( + args, kwargs + ): + method = getattr(self.value, name) + return variables.ConstantVariable.create( + method( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif name == "highestPriorityTypeId": + return variables.EnumVariable(self.value.highestPriorityTypeId()) + return super().call_method(tx, name, args, kwargs) + + +class FuncTorchInterpreterVariable(BaseTorchVariable): + """represents torch._functorch.pyfunctorch.FuncTorchInterpreter""" + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return cls(value, source=source) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if name == "key": + return variables.EnumVariable(self.value.key()) + elif name == "process": + return tx.inline_user_function_return( + VariableTracker.build(tx, self.value.process.__func__), + [self] + args, + kwargs, + ) + elif name in ["level", "batch_size", "randomness"]: + return variables.ConstantVariable.create(getattr(self.value, name)()) + elif name == "lower": + assert not args and not kwargs + return variables.TemporarilyPopInterpreterStackCtxManagerVariable.create( + tx, None + ) + return super().call_method(tx, name, args, kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a86eb4f017f88b36fcd7ac94c488352bc4e26f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py @@ -0,0 +1,761 @@ +"""TorchDynamo support for __torch_function__ tensor subclasses. + +This module implements support for tensor subclasses with __torch_function__ overrides. +A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles +dispatching __torch_function__ on attribute accesses, method calls, and torch API calls. + +Unsupported features: +- Triggering __torch_function__ on tensor subclass non-tensor custom attributes +- Graph breaking on mutating guardable tensor properties within a __torch_function__ context + (can cause excessive recompiles in certain cases) +- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor + argument positions of Torch API calls + +Supported features: +- Static method implementations of __torch_function__ on custom objects (triggers on torch + API calls with the object as any argument) +- Triggering __torch_function__ on torch API calls with tensor subclass arguments +- __torch_function__ calls on base tensor attribute access and method calls for tensor + subclass instances +- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object + arguments in any position + +See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w +for more information on the design. +""" + +import collections +import contextlib +import functools +import inspect +import operator +from collections.abc import Generator, Iterable, Sequence +from types import TracebackType +from typing import Any, Optional, TYPE_CHECKING + +import torch._C +import torch.utils._pytree as pytree +from torch._guards import Source +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) +from torch.utils._device import DeviceContext + +from .. import graph_break_hints +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode +from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) +from .base import VariableTracker +from .constant import ConstantVariable +from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable +from .lazy import LazyVariableTracker +from .lists import TupleVariable +from .tensor import TensorSubclassVariable, TensorVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +bin_ops = [ + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +] + +bin_int_ops = [ + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +] + +un_int_ops = [operator.invert] + +tensor_and_int_ops = [ + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +] + +un_ops = [ + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +] + + +banned_attrs = [ + fn.__self__.__name__ # type: ignore[attr-defined] + for fn in get_default_nowrap_functions() + if is_tensor_base_attr_getter(fn) +] + + +@functools.cache +def get_prev_stack_var_name() -> str: + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool: + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the function across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__ + ) + + def __init__( + self, + value: Optional[TorchFunctionMode], + source: Optional[Source] = None, + **kwargs: Any, + ): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + # needed for BC with calling enter from CM code + self.cm_obj = value # type: ignore[assignment] + self.source = source # type: ignore[assignment] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # This shouldn't be called unless we have a source + assert self.source + self.source.reconstruct(codegen) + + def module_name(self) -> str: + return self.value.__module__ + + def fn_name(self) -> str: + return type(self.value).__name__ + + def python_type(self) -> type: + return type(self.value) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + return call_torch_function( + tx, + get_torch_function_fn(tx, self), # type: ignore[arg-type] + fn, + types, + args, + kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return False + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self) -> None: + self.stack: list[Any] = [] + + def __enter__(self) -> None: + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self) -> Generator[None, None, None]: + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + +class SymbolicTorchFunctionState: + def __init__(self, py_stack: Iterable[Any]) -> None: + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: collections.deque[TorchFunctionModeVariable] = ( + collections.deque() + ) + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type] + ) + + def in_torch_function_mode(self) -> bool: + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self) -> TorchFunctionModeVariable: + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None: + self.mode_stack.append(mode_var) + + def call_torch_function_mode( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining( + self, + ) -> Generator[TorchFunctionModeVariable, None, None]: + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment] + try: + yield self.cur_mode # type: ignore[misc] + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) # type: ignore[arg-type] + + +class TorchFunctionModeStackVariable(VariableTracker): + """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" + + # singleton value representing the global torch function mode stack + # singleton (it exists in C++) + stack_value_singleton = object() + + # offset is used to track if we have inserted/removed a + # device context which is always placed at the bottom of the stack + # if a device context is inserted, the graph will run this mutation + # so when we want to reconstruct any other modes on the stack + # their indices should be shifted right by 1 (+1) + # Conversely, if there was a device context on the stack, and the graph + # mutates the stack to remove that context (set default device to None) + # each of the indices of other modes should be shifted left by 1 (-1) + offset = 0 + + def __init__( + self, + source: Source, + symbolic_stack: collections.deque[TorchFunctionModeVariable], + ) -> None: + self.source = source + self.symbolic_stack = symbolic_stack + + @classmethod + def reset(cls) -> None: + cls.offset = 0 + + @classmethod + def register_mutation(cls, tx: "InstructionTranslator") -> None: + if cls.stack_value_singleton not in tx.output.side_effects: + var = cls( + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + ) + tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) + tx.output.side_effects.mutation(var) + + @classmethod + def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None: + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + return + else: + cls.offset += 1 + stack.insert( + 0, + TorchFunctionModeVariable( + None, source=TorchFunctionModeStackSource(-cls.offset) + ), + ) + + @classmethod + def clear_default_device(cls, tx: "InstructionTranslator") -> None: + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + stack.popleft() + cls.offset -= 1 + + @staticmethod + def is_device_context(var: TorchFunctionModeVariable) -> bool: + return isinstance(var.value, DeviceContext) or var.value is None + + @classmethod + def get_mode_index(cls, ind: int) -> int: + return ind + cls.offset + + +def _get_all_args( + args: Iterable[Any], kwargs: dict[str, Any] +) -> Iterable[VariableTracker]: + return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) + + +def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]: + from collections import deque + + from .dicts import ConstDictVariable + from .lists import ListVariable + + vts = deque(vts) + output = [] + + while vts: + vt = vts.popleft() + + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined] + vt.realize() + + if vt.is_realized(): + if isinstance(vt, ListVariable): + vts.extend(vt.items) + continue + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + continue + + output.append(vt) + + return output + + +def _get_subclass_type(var: VariableTracker) -> type: + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + return var.python_type() + + +def _get_subclass_type_var( + tx: "InstructionTranslator", var: VariableTracker +) -> VariableTracker: + if isinstance(var, TensorWithTFOverrideVariable): + return var.class_type_var(tx) + elif isinstance(var, UserDefinedObjectVariable): + source = var.source and TypeSource(var.source) + return VariableTracker.build(tx, var.python_type(), source) + else: + raise AssertionError(f"Unexpected type {type(var)}") + + +def _is_attr_overridden( + tx: "InstructionTranslator", var: VariableTracker, name: str +) -> bool: + if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)): + return False + import torch + + overridden = False + try: + attr_val = inspect.getattr_static(var.python_type(), name) + overridden |= attr_val != getattr(torch.Tensor, name) + except AttributeError: + pass + + return overridden + + +def call_torch_function( + tx: "InstructionTranslator", + torch_function_var: VariableTracker, + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: + # This emulates calling __torch_function__, which has a signature + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # + # Also notice the `cls` is not explicitly passed in the reference + # implementations: + # 1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374 # noqa: B950 + # 2. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/overrides.py#L1741-L1743 + tf_args = [ + fn, + types, + VariableTracker.build(tx, tuple(args)), + VariableTracker.build(tx, kwargs), + ] + return torch_function_var.call_function(tx, tf_args, {}) + + +def get_torch_function_fn( + tx: "InstructionTranslator", vt: VariableTracker +) -> VariableTracker: + # The underlying function could be a classmethod, staticmethod, regular + # function or a function with C-implementation. It doesn't matter as long as + # they satisfy the calling convention in `call_torch_function`. + from .builtin import BuiltinVariable + + args = [vt, ConstantVariable("__torch_function__")] + func_vt = BuiltinVariable(getattr).call_function(tx, args, {}) + return func_vt + + +def can_dispatch_torch_function( + tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any] +) -> bool: + has_overridden_args = any( + has_torch_function(arg) for arg in _get_all_args(args, kwargs) + ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) + + +def dispatch_torch_function( + tx: "InstructionTranslator", + fn: VariableTracker, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: + """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" + + all_args = _get_all_args(args, kwargs) + overloaded_args = _get_overloaded_args( + [arg for arg in all_args if has_torch_function(arg)], + _get_subclass_type, + ) + + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not res.is_constant_match(NotImplemented): + return res + + for arg in overloaded_args: + res = arg.call_torch_function( + tx, + fn, + types, + args, + kwargs, + ) + + if not res.is_constant_match(NotImplemented): + return res + + unimplemented( + gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code", + context=f"{fn=}, {args=}, {kwargs=}", + explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + +class TensorWithTFOverrideVariable(TensorVariable): + """ + Represents a tensor subclass instance with a __torch_function__ override. + """ + + @classmethod + def from_tensor_var( + cls, + tx: "InstructionTranslator", + tensor_var: VariableTracker, + class_type: type, + cls_source: Source, + ) -> "TensorWithTFOverrideVariable": + # [Note: __torch_function__] coerce `tensor_var` into a + # TensorWithTFOverrideVariable. In eager, this is just a type change. + import torch + + # This simulates shallow-copying the tensor object. + kwargs = dict(tensor_var.__dict__) + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" + ) + var = cls(class_type=class_type, **kwargs) + var.install_global(tx) + return var + + def install_global(self, tx: "InstructionTranslator") -> None: + # stash the subclass type to rewrap an output tensor if needed + # this is needed because the actual type needs to be available + # each time the compiled artifact is run and outputs a wrapped tensor. + if self.global_mangled_class_name(tx) not in tx.output.global_scope: + # Safe because global_mangled_class_name figures it out + tx.output.install_global_unsafe( + self.global_mangled_class_name(tx), self.class_type + ) + + def python_type(self) -> type: + return self.class_type + + def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker: + return TensorSubclassVariable( + self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) + ) + + def global_mangled_class_name(self, tx: "InstructionTranslator") -> str: + return get_safe_global_name( + tx, f"__subclass_{self.class_type.__name__}", self.class_type + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # [Note: __torch_function__] We currently only support attributes that are defined on + # base tensors, custom attribute accesses will graph break. + import torch + + # I think only `_base` is breaking because we aren't modelling view + # relationship perfectly in some scenarios. + if name in banned_attrs: + unimplemented( + gb_type="Unsupported tensor subclass attribute access", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Handle non-overridden attributes inherited from `torch.Tensor`. + attr_is_overridden = _is_attr_overridden(tx, self, name) + if ( + hasattr(torch.Tensor, name) + and not attr_is_overridden + and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) + ): + args = [self] + kwargs: dict[Any, Any] = {} + if can_dispatch_torch_function(tx, args, kwargs): + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + args, + kwargs, + ) + else: + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. + try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) + if isinstance(attr, types.FunctionType): + install_guard(attr_source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = VariableTracker.build(tx, getter, source=getter_source) + return getter_var.call_function(tx, [self], {}) + + elif isinstance(attr, classmethod): + return UserMethodVariable( + attr.__func__, self.class_type_var(tx), source=attr_source + ) + + elif attr_is_overridden: + unimplemented( + gb_type="Unsupported tensor subclass overridden attribute access", + context=f"{name}", + explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + f"Renaming attribute `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return super().var_getattr(tx, name) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: + # NOTE this assumes `__torch_function__` isn't modified during tracing. + if not hasattr(self, "torch_function_fn"): + self.torch_function_fn = get_torch_function_fn(tx, self) + + return call_torch_function( + tx, + self.torch_function_fn, + fn, + types, + args, + kwargs, + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This code block implements inlining the __torch_function__ override + # of `call_method`. + tf_args = [self] + list(args) + if can_dispatch_torch_function(tx, tf_args, kwargs): + import torch + + if _is_attr_overridden(tx, self, name): + unimplemented( + gb_type="Tensor subclass overridden method call", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid calling {name} of tensor subclass in torch.compile region", + f"Renaming method `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # [Note: __torch_function__] Currently we only support methods that are defined on tensor + # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality + # We've established with the above check that the method is not overridden, so we guard that the method is the same + # as the impl defined on tensor and retrieve it + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + value = inspect.getattr_static(self.python_type(), name) + else: + source = None + value = getattr(torch.Tensor, name) + func_var = VariableTracker.build(tx, value, source) + return dispatch_torch_function(tx, func_var, tf_args, kwargs) + else: + return super().call_method(tx, name, args, kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b39b2f9b53e0ba6c04db41dc616ad3d5daea4a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py @@ -0,0 +1,2397 @@ +# mypy: ignore-errors + +""" +This module contains variable classes for handling user-defined objects in Dynamo's tracing system. + +The key classes are: +- UserDefinedVariable: Base class for representing custom Python objects +- UserDefinedClassVariable: Handles Python class objects/types +- UserDefinedObjectVariable: Fallback class for instance objects, with support for method calls, + attribute access, and other Python object behaviors. +- Specialized subclasses for common patterns: + - UserDefinedDictVariable: For dict subclasses + - UserDefinedSetVariable: For set subclasses + - UserDefinedTupleVariable: For tuple subclasses + - UserDefinedExceptionObjectVariable: For exception subclasses + - FrozenDataClassVariable: Special handling of frozen dataclasses + - MutableMappingVariable: For collections.abc.MutableMapping subclasses + +Dynamo specializes to VariableTracker subclasses like FrozenDataClassVariable if available; if no +subclass qualifies, it falls back to UserDefinedObjectVariable. + +These classes help Dynamo track and handle arbitrary Python objects during tracing, +maintaining proper semantics while enabling optimizations where possible. +""" + +import _collections +import builtins +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import itertools +import random +import sys +import threading +import types +import warnings +import weakref +from typing import TYPE_CHECKING +from typing_extensions import is_typeddict + +import torch._dynamo.config +import torch.nn +from torch._guards import TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ( + handle_observed_exception, + ObservedAttributeError, + ObservedKeyError, + ObservedTypeError, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, +) +from ..graph_bytecode_inputs import get_external_object_by_index +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + CallFunctionNoArgsSource, + DataclassFieldsSource, + DictGetItemSource, + GetItemSource, + RandomValueSource, + TypeDictSource, + TypeMROSource, + TypeSource, + UnspecializedParamBufferSource, +) +from ..utils import ( + check_constant_args, + cmp_name_to_op_mapping, + dict_methods, + frozenset_methods, + get_custom_getattr, + has_torch_function, + is_frozen_dataclass, + is_lru_cache_wrapped_function, + is_namedtuple_cls, + is_wrapper_or_member_descriptor, + istype, + list_methods, + namedtuple_fields, + object_has_getattribute, + proxy_args_kwargs, + raise_args_mismatch, + raise_on_overridden_hash, + set_methods, + tensortype_to_dtype, + tuple_methods, + unpatched_nn_module_getattr, +) +from .base import raise_type_error_exc, ValueMutationNew, VariableTracker +from .dicts import ConstDictVariable, DefaultDictVariable +from .lists import SizeVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +try: + from torch.utils._cxx_pytree import PyTreeSpec +except ImportError: + PyTreeSpec = type(None) + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .constant import ConstantVariable + + +def is_standard_setattr(val): + return val in (object.__setattr__, BaseException.__setattr__) + + +def is_standard_delattr(val): + return val in (object.__delattr__, BaseException.__delattr__) + + +def is_forbidden_context_manager(ctx): + f_ctxs = [] + + try: + from _pytest.python_api import RaisesContext + from _pytest.recwarn import WarningsChecker + + f_ctxs.append(RaisesContext) + f_ctxs.append(WarningsChecker) + except ImportError: + pass + + if m := sys.modules.get("torch.testing._internal.jit_utils"): + f_ctxs.append(m._AssertRaisesRegexWithHighlightContext) + + return ctx in f_ctxs + + +def is_cython_function(obj): + return ( + callable(obj) + and hasattr(type(obj), "__name__") + and type(obj).__name__ == "cython_function_or_method" + ) + + +class UserDefinedVariable(VariableTracker): + value: object + + +class UserDefinedClassVariable(UserDefinedVariable): + value: type[object] + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + # Used when we materialize class.__dict__ to a MappingProxyObject. In + # this case, we don't want to allow mutation in the class because there + # is no way to reflect it in the created MappingProxyVariable. + self.ban_mutation = False + + def as_python_constant(self): + return self.value + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value})" + + @staticmethod + @functools.cache + def _constant_fold_classes(): + return { + torch.device, + torch.finfo, + torch.iinfo, + torch.Size, + } + + @staticmethod + @functools.cache + def _in_graph_classes(): + _in_graph_class_list = { + torch.Tensor, + torch.cuda.FloatTensor, + torch.cuda.DoubleTensor, + torch.cuda.HalfTensor, + torch.cuda.BFloat16Tensor, + torch.cuda.ByteTensor, + torch.cuda.CharTensor, + torch.cuda.IntTensor, + torch.cuda.ShortTensor, + torch.cuda.LongTensor, + torch.Stream, + torch.Event, + torch.cuda.Stream, + torch.cuda.Event, + torch.xpu.Stream, + torch.xpu.Event, + } + if hasattr(torch, "hpu"): + _in_graph_class_list.update( + { + torch.hpu.Stream, + torch.hpu.Event, + } + ) + + return set(tensortype_to_dtype.keys()) | _in_graph_class_list + + @staticmethod + @functools.cache + def supported_c_new_functions(): + exceptions = [ + getattr(builtins, name).__new__ + for name in dir(builtins) + if isinstance(getattr(builtins, name), type) + and issubclass(getattr(builtins, name), BaseException) + ] + return { + object.__new__, + dict.__new__, + set.__new__, + frozenset.__new__, + tuple.__new__, + list.__new__, + }.union(exceptions) + + @staticmethod + def is_supported_new_method(value): + # TODO(anijain2305) - Extend this to support objects with default tp_new + # functions. + return value in UserDefinedClassVariable.supported_c_new_functions() + + def can_constant_fold_through(self): + return self.value in self._constant_fold_classes() + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from . import ConstantVariable, EnumVariable + + source = AttrSource(self.source, name) if self.source is not None else None + + if name == "__name__": + return ConstantVariable.create(self.value.__name__) + elif name == "__qualname__": + return ConstantVariable.create(self.value.__qualname__) + elif name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + elif name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) + + # Special handling of collections.OrderedDict.fromkeys() + # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with + # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method(). + # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys), + # and we need duplicate code to handle both cases. + if ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + return super().var_getattr(tx, name) + + try: + obj = inspect.getattr_static(self.value, name) + except AttributeError: + if type(self.value) is type: + raise_observed_exception( + AttributeError, + tx, + args=[ + f"type object '{self.value.__name__}' has no attribute '{name}'" + ], + ) + else: + # Cannot reason about classes with a custom metaclass + # See: test_functions::test_getattr_metaclass + obj = None + + if name == "__new__" and UserDefinedClassVariable.is_supported_new_method(obj): + return super().var_getattr(tx, name) + + if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType): + return variables.GetAttrVariable(self, name, source=source) + + if isinstance(obj, staticmethod): + return VariableTracker.build(tx, obj.__get__(self.value), source) + elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + fget_vt = VariableTracker.build(tx, obj.__func__.fget) + return fget_vt.call_function(tx, [self], {}) + return variables.UserMethodVariable(obj.__func__, self, source=source) + elif isinstance(obj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static(dict, "fromkeys") + # inspect.getattr_static(itertools.chain, "from_iterable") + func = obj.__get__(None, self.value) + return VariableTracker.build(tx, func, source) + elif source: + if inspect.ismemberdescriptor(obj): + return VariableTracker.build(tx, obj.__get__(self.value), source) + + if ConstantVariable.is_literal(obj): + return ConstantVariable.create(obj) + elif isinstance(obj, enum.Enum): + return EnumVariable(obj) + elif self.value is collections.OrderedDict: + return variables.GetAttrVariable(self, name) + elif name in getattr(self.value, "__dict__", {}) or ( + self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ): + if source: + return VariableTracker.build(tx, obj, source) + + if ( + source + and not inspect.ismethoddescriptor(obj) + and not is_wrapper_or_member_descriptor(obj) + ): + return VariableTracker.build(tx, obj, source) + + return super().var_getattr(tx, name) + + def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs): + """ + functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional loss call: input, target, optional_output + """ + from . import ConstantVariable + + def normalize_args( + weight=ConstantVariable.create(None), + size_average=ConstantVariable.create(None), + ignore_index=ConstantVariable.create(-100), + reduce=ConstantVariable.create(None), + reduction=ConstantVariable.create("mean"), + label_smoothing=ConstantVariable.create(0.0), + ): + return ( + weight, + size_average, + ignore_index, + reduce, + reduction, + label_smoothing, + ) + + ( + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ) = normalize_args(*args, **kwargs) + + def fake_cross_entropy_loss(input, target): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.nn.functional.cross_entropy, + *proxy_args_kwargs( + [ + input, + target, + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ], + {}, + ), + ), + ) + + return variables.LambdaVariable(fake_cross_entropy_loss) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + name == "__subclasses__" + and len(args) == 0 + and not kwargs + and "__subclasses__" not in self.value.__dict__ + ): + source = self.source + if self.source: + source = AttrSource(self.source, "__subclasses__") + source = CallFunctionNoArgsSource(source) + return VariableTracker.build(tx, self.value.__subclasses__(), source) + elif ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + return variables.BuiltinVariable.call_custom_dict_fromkeys( + tx, self.value, *args, **kwargs + ) + elif self.value is collections.OrderedDict and name == "move_to_end": + return args[0].call_method(tx, name, [*args[1:]], kwargs) + elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value == args[0].value) + elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value != args[0].value) + elif issubclass(self.value, dict) and name != "__new__": + # __new__ is handled below + return variables.BuiltinVariable(dict).call_method(tx, name, args, kwargs) + elif issubclass(self.value, (set, frozenset)) and name != "__new__": + # __new__ is handled below + return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs) + elif ( + name == "__new__" + and self.value is collections.OrderedDict + and isinstance(args[0], UserDefinedClassVariable) + and args[0].value is collections.OrderedDict + ): + if kwargs and len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.ConstDictVariable( + {}, collections.OrderedDict, mutation_type=ValueMutationNew() + ) + elif name == "__new__" and UserDefinedClassVariable.is_supported_new_method( + self.value.__new__ + ): + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + elif name == "__setattr__" and self.ban_mutation: + unimplemented( + gb_type="Class attribute mutation when the __dict__ was already materialized", + context=str(self.value), + explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", + hints=graph_break_hints.SUPPORTABLE, + ) + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..side_effects import SideEffects + from .builder import wrap_fx_proxy + + constant_args = check_constant_args(args, kwargs) + + if self.can_constant_fold_through() and constant_args: + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif self.value is torch.nn.CrossEntropyLoss: + return self._call_cross_entropy_loss(tx, args, kwargs) + elif self.value is contextlib.nullcontext: + # import here to avoid circular dependency + from .ctx_manager import NullContextVariable + + return NullContextVariable(*args, **kwargs) + elif self.value is collections.OrderedDict: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [self, *args], + kwargs, + ) + elif self.value is collections.defaultdict: + if len(args) == 0: + default_factory = variables.ConstantVariable.create(None) + else: + default_factory, *args = args + dict_vt = variables.BuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) + return DefaultDictVariable( + dict_vt.items, + collections.defaultdict, + default_factory, + mutation_type=ValueMutationNew(), + ) + elif is_typeddict(self.value): + if self.value.__optional_keys__: + unimplemented( + gb_type="TypedDict with optional keys", + context=str(self.value), + explanation="Dyanmo does not support tracing TypedDict with optional keys", + hints=[ + "Avoid using TypedDict with optional keys", + *graph_break_hints.SUPPORTABLE, + ], + ) + return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) + elif self.value is collections.deque: + maxlen = variables.ConstantVariable.create(None) + + def deque_signature(iterable=None, maxlen=None): + pass + + try: + bound_args = inspect.signature(deque_signature).bind(*args, **kwargs) + except TypeError as e: + unimplemented( + gb_type="collections.deque() with bad arguments", + context=f"args={args}, kwargs={kwargs}", + explanation="Detected call to collections.deque() with bad arguments.", + hints=[ + "Fix the call to collections.deque().", + *graph_break_hints.USER_ERROR, + ], + from_exc=e, + ) + + if "iterable" in bound_args.arguments: + if not bound_args.arguments["iterable"].has_force_unpack_var_sequence( + tx + ): + unimplemented( + gb_type="collections.deque() with bad iterable argument", + context=f"args={args}, kwargs={kwargs}", + explanation="Call to collections.deque() has an iterable argument that Dynamo cannot " + "convert to a list.", + hints=[ + "Use a simpler sequence type that Dynamo can convert to a list " + "(e.g. list, tuple, list iterator, etc.)", + *graph_break_hints.USER_ERROR, + ], + ) + items = bound_args.arguments["iterable"].force_unpack_var_sequence(tx) + else: + items = [] + + if "maxlen" in bound_args.arguments: + maxlen = bound_args.arguments["maxlen"] + + return variables.lists.DequeVariable( + items, maxlen=maxlen, mutation_type=ValueMutationNew() + ) + elif self.value is weakref.ref: + if len(args) > 1: + callback = args[1] + else: + callback = variables.ConstantVariable.create(None) + return variables.WeakRefVariable(args[0], callback) + elif self.value is functools.partial: + if not args: + unimplemented( + gb_type="missing args to functools.partial", + context="", + explanation="functools.partial requires at least one argument", + hints=[ + "Fix the functools.partial call.", + *graph_break_hints.USER_ERROR, + ], + ) + # The first arg, a callable (the ctor below will assert on types) + fn = args[0] + rest_args = args[1:] + # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the + # args and keywords + return variables.functions.FunctoolsPartialVariable( + fn, args=rest_args, keywords=kwargs + ) + elif self.value is warnings.catch_warnings and not args: + return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs) + elif self.value is torch.cuda.device and not kwargs and len(args) == 1: + if not args[0].is_python_constant(): + raise_type_error_exc( + tx, "torch.cuda.device() requires a constant argument" + ) + return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant()) + elif ( + issubclass(type(self.value), type) + and hasattr( + self.value, "__enter__" + ) # TODO(voz): These can invoke user code! + and hasattr( + self.value, "__exit__" + ) # TODO(voz): These can invoke user code! + and self.is_standard_new() + and SideEffects.cls_supports_mutation_side_effects(self.value) + and self.source + and not is_forbidden_context_manager(self.value) + ): + from . import TorchCtxManagerClassVariable + from .functions import ( + BaseUserFunctionVariable, + FunctionDecoratedByContextlibContextManagerVariable, + ) + + # graph break on any contextlib.* that it is not contextlib.contextmanager + # Some of the APIs below are not supported because they rely on features + # that Dynamo doesn't play well today (i.e. contextlib.suppress) + if self.value in ( + contextlib._AsyncGeneratorContextManager, + contextlib.closing, + contextlib.redirect_stdout, + contextlib.redirect_stderr, + contextlib.suppress, + contextlib.ExitStack, + contextlib.AsyncExitStack, + ): + # We are not changing the behavior of Dynamo as these function were + # already ignored on trace_rules.py before #136033 landed + unimplemented( + gb_type="unsupported contextlib.* API", + context=f"{self.value}", + explanation=f"{self.value} not supported. This may be due to its use of " + "context-specific operations that are not supported in " + "Dynamo yet (i.e. Exception handling)", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if self.value is contextlib._GeneratorContextManager and isinstance( + args[0], (BaseUserFunctionVariable, TorchCtxManagerClassVariable) + ): + if not torch._dynamo.config.enable_trace_contextlib: + unimplemented( + gb_type="attempted to trace contextlib.contextmanager", + context=f"args={args}", + explanation="Tracing contextlib.contextmanager is disabled.", + hints=[ + "Set torch._dynamo.config.enable_trace_contextlib = True", + ], + ) + + # Special treatments for certain context managers created via + # contextlib, because + # 1. we (pytorch) own their impls + # 2. it's tedious to trace through them, so we effectively + # "allow in graph" them without sacrificing soundness. + # + # We would typically reach here via either + # 1. the instance construction in `with ctx_manager(...):`: + # https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L301 + # 2. calling a function decorated with a context manager: + # https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L122 + # + # So we basically trace through the surface part of the + # contextlib code, and then special case the shared remaining + # logic (the actual context manager instance construction and + # usage later on). + if isinstance(args[0], TorchCtxManagerClassVariable): + fn_var = args[0] + args_list = args[1].items + kwargs_dict = args[2].keys_as_python_constant() + return fn_var.call_function(tx, args_list, kwargs_dict) + + # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable + # if the function is annotated with @contextlib.contextmanager + # This shouldn't be necessary once generator functions are fully + # supported in dynamo + args = [ + FunctionDecoratedByContextlibContextManagerVariable( + args[0], source=args[0].source + ) + ] + args[1:] + + cm_obj = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), + self, + args, + ) + cm_obj.call_method(tx, "__init__", args, kwargs) + return cm_obj + elif is_namedtuple_cls(self.value): + fields = namedtuple_fields(self.value) + # check if this a quasi-namedtuple or a real one + if self.value.__module__ == "torch.return_types": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + "torch.return_types", + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items = args[0].force_unpack_var_sequence(tx) + else: + field_defaults = self.value._field_defaults + + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + var_tracker_kwargs = {} + for field_name, var_tracker in zip(fields, items): + if var_tracker is None: + if field_name in kwargs: + field_var = kwargs[field_name] + else: + assert field_name in field_defaults + field_var = VariableTracker.build( + tx, field_defaults[field_name] + ) + var_tracker_kwargs[field_name] = field_var + + for name, value in var_tracker_kwargs.items(): + assert name in fields + items[fields.index(name)] = value + + assert all(x is not None for x in items) + + # Modify mutability of namedtuple for sourcelesss instantiations. + from .base import AttributeMutationNew + + return variables.NamedTupleVariable( + items, self.value, mutation_type=AttributeMutationNew() + ) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) + elif is_frozen_dataclass(self.value) and self.is_standard_new(): + fields = dataclasses.fields(self.value) + fields_source = DataclassFieldsSource(self.source) + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + default_kwargs = {} + for ind, field, var_tracker in zip(itertools.count(), fields, items): + if var_tracker is None: + if field.name in kwargs: + var_tracker = kwargs[field.name] + else: + if not field.init: + continue + + if field.default is not dataclasses.MISSING: + var_tracker = VariableTracker.build( + tx, + field.default, + source=AttrSource( + GetItemSource(fields_source, ind), "default" + ), + ) + elif field.default_factory is not dataclasses.MISSING: + factory_fn = VariableTracker.build( + tx, field.default_factory + ) + var_tracker = factory_fn.call_function(tx, [], {}) + else: + # if we are subclass, the constructor could possibly + # be missing args + continue + + default_kwargs[field.name] = var_tracker + kwargs.update(default_kwargs) + + var = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), self, args + ) + var.call_method(tx, "__init__", args, kwargs) + return var + elif ( + self.value in self._in_graph_classes() + or is_traceable_wrapper_subclass_type(self.value) + ): + # torch.LongTensor cannot accept a list of FakeTensors. + # So we stack the list of FakeTensors instead. + if ( + np + and self.value in tensortype_to_dtype + and len(args) == 1 + and isinstance(args[0], variables.ListVariable) + and len(args[0].items) > 1 + and all(x.is_tensor() for x in args[0].items) + ): + # Stack FakeTensor + stacked = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.stack, + *proxy_args_kwargs(args, kwargs), + ), + ) + args = [stacked] + + if issubclass(self.value, torch.Stream): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created stream for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + stream = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import StreamVariable + + ind = register_graph_created_object( + stream, + StreamVariable.make_construct_in_graph_stream_fn( + var_args, var_kwargs + ), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + ) + elif issubclass(self.value, torch.Event): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created event for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + event = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import EventVariable + + ind = register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + var_args, var_kwargs + ), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + ) + else: + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + *proxy_args_kwargs(args, kwargs), + ), + ) + + return tensor_variable + elif self.value is random.Random: + if len(args) == 1 and args[0].is_python_constant(): + seed = args[0].as_python_constant() + else: + seed = None + random_object = random.Random(seed) + return RandomVariable(random_object) + elif ( + self.value is types.MappingProxyType + and len(args) == 1 + and isinstance(args[0], variables.ConstDictVariable) + ): + # types.MappingProxyType is a read-only proxy of the dict. If the + # original dict changes, the changes are reflected in proxy as well. + return variables.MappingProxyVariable(args[0]) + elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source: + with do_not_convert_to_tracable_parameter(): + return tx.inline_user_function_return( + VariableTracker.build( + tx, polyfills.instantiate_user_defined_class_object + ), + [self, *args], + kwargs, + ) + return super().call_function(tx, args, kwargs) + + def is_standard_new(self): + """Check for __new__ being overridden""" + new_fn = inspect.getattr_static(self.value, "__new__", None) + if isinstance(new_fn, staticmethod): + new_fn = new_fn.__func__ + return new_fn is object.__new__ + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + if self.source: + source = AttrSource(self.source, name) + install_guard(source.make_guard(GuardBuilder.HASATTR)) + return variables.ConstantVariable(hasattr(self.value, name)) + return super().call_obj_hasattr(tx, name) + + def const_getattr(self, tx: "InstructionTranslator", name): + if name == "__name__": + return self.value.__name__ + return super().const_getattr(tx, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return ( + isinstance(other, variables.UserDefinedClassVariable) + and self.value is other.value + ) + + +class UserDefinedExceptionClassVariable(UserDefinedClassVariable): + @property + def fn(self): + return self.value + + +class NO_SUCH_SUBOBJ: + pass + + +def call_random_fn(tx, fn, args, kwargs): + from .builder import VariableBuilder + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + random_call_index = len(tx.output.random_calls) + example_value = fn(*args, **kwargs) + source = RandomValueSource(random_call_index) + tx.output.random_calls.append((fn, args, kwargs)) + # TODO: arguably, this should route to wrap_symint/wrap_symfloat + # (currently hypothetical), but I'm not going to poke my hand in + # this nest for now + return VariableBuilder(tx, source).wrap_unspecialized_primitive(example_value) + + +class UserDefinedObjectVariable(UserDefinedVariable): + """ + Mostly objects of defined type. Catch-all for something where we only know the type. + """ + + _nonvar_fields = { + "value", + "value_type", + "attrs_directly_modifed_on_dict", + *UserDefinedVariable._nonvar_fields, + } + + def __init__( + self, + value, + *, + value_type=None, + cls_source=None, + base_cls_vt=None, + init_args=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.value = value + self.value_type = value_type or type(value) + assert type(value) is self.value_type + # This is used with __new__, when the new object is sourceless but the user class can be sourceful. + self.cls_source = cls_source + if cls_source is None and self.source is not None: + self.cls_source = TypeSource(self.source) + + # These attributes are used to reconstruct the user defined object. The + # pseudo code looks like this. Builtin C __new__ do not support kwargs, + # so init_args is sufficient. + # obj = base_cls.__new__(user_cls, *args) + self.base_cls_vt = base_cls_vt + self.init_args = init_args + + # This records names of the attributes that were modified via instance + # `__dict__` directly, rather than the normal setattr path. + # + # TODO consider emulating `obj.__dict__` as a `ConstDictVariable` to get + # rid of these workarounds here and in `GetAttrVariable`. + self.attrs_directly_modifed_on_dict = set() + + import torch.utils._pytree as pytree + + self.is_pytree_constant_class = pytree.is_constant_class(self.value_type) + if pytree.is_constant_class(self.value_type) and self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + + def __str__(self) -> str: + inner = self.value_type.__name__ + if inner in [ + "builtin_function_or_method", + "getset_descriptor", + "method_descriptor", + "method", + ]: + inner = str(getattr(self.value, "__name__", None)) + return f"{self.__class__.__name__}({inner})" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + def is_underlying_vt_modified(self, side_effects): + return False + + def python_type(self): + return self.value_type + + def as_python_constant(self): + if self.is_pytree_constant_class and self.source: + # NOTE pytree constants created in the torch.compile region will + # NOT be guarded (even though they have a source set) + return self.value + # TODO else try reconstructing the object by, e.g., leveraging side + # effects and `as_python_constant`. + return super().as_python_constant() + + def guard_as_python_constant(self): + if self.source: + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + return self.value + return super().guard_as_python_constant() + + def torch_function_check(self): + assert has_torch_function(self), ( + f"calling torch function on object without __torch_function__ {self}" + ) + + def get_torch_fn(self, tx): + self.torch_function_check() + from .torch_function import get_torch_function_fn + + return get_torch_function_fn(tx, self) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + self.torch_function_check() + + from .torch_function import call_torch_function + + return call_torch_function( + tx, + self.get_torch_fn(tx), + fn, + types, + args, + kwargs, + ) + + @staticmethod + @functools.cache + def _supported_random_functions(): + fns = { + random.random, + random.randint, + random.randrange, + random.uniform, + } + return fns + + def _maybe_get_baseclass_method(self, name): + if name not in getattr(self.value, "__dict__", {}): + try: + return inspect.getattr_static(type(self.value), name) + except AttributeError: + pass + return None + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable, UserMethodVariable + + method = self._maybe_get_baseclass_method(name) + if method is not None: + if method is object.__init__: + return ConstantVariable.create(None) + + if is_standard_setattr(method) or isinstance(self.value, threading.local): + return self.method_setattr_standard(tx, *args, **kwargs) + + if is_standard_delattr(method): + return self.method_setattr_standard( + tx, args[0], variables.DeletedVariable() + ) + + if method is object.__eq__ and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(other, UserDefinedObjectVariable): + return variables.ConstantVariable.create(NotImplemented) + + # TODO(anijain2305) - Identity checking should already be a part + # of the cmp_eq polyfill function. + return ConstantVariable.create(self.value is other.value) + + if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( + self.value, types.GeneratorType + ): + unimplemented( + gb_type="call_method on generator", + context=f"object={self.value}, method={name}, args={args}, kwargs={kwargs}", + explanation="Detected a method call to a user-defined generator object. " + "This is not fully supported.", + hints=[ + "Set `torch._dynamo.config.enable_faithful_generator_behavior = False`. Note that this " + "may cause silent incorrectness, since we will eagerly unpack generators instead of lazily " + "evaluating them.", + ], + ) + + # check for methods implemented in C++ + if isinstance(method, types.FunctionType): + source = self.source + source_fn = None + if source: + source_fn = self.get_source_by_walking_mro(name) + # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init + return UserMethodVariable( + method, self, source_fn=source_fn, source=source + ).call_function(tx, args, kwargs) + + if method is list.__len__ and self.source and not (args or kwargs): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return ConstantVariable(len(self.value)) + + return super().call_method(tx, name, args, kwargs) + + def method_setattr_standard( + self, tx: "InstructionTranslator", name, value, directly_update_dict=False + ): + try: + name = name.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="non-const setattr name on user-defined object", + context=f"object={self}, name={name}, value={value}", + explanation="Detected a call to `setattr` of a user-defined object with a non-constant name.", + hints=["Ensure that the name is a string."], + ) + assert tx.output.side_effects.is_attribute_mutation(self), ( + "Attempted setattr on a user-defined object that does not have " + "an AttributeMutation mutation_type" + ) + + if directly_update_dict: + self.attrs_directly_modifed_on_dict.add(name) + else: + tmp = self.try_get_descritor_and_setter_py_func(name) + if tmp: + descriptor, setter = tmp + # Emulate + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452 + desc_source = None + func_source = None + if self.cls_source: + desc_source = self.get_source_by_walking_mro(name) + # use `type(...)` to ignore instance attrs. + func_source = AttrSource(TypeSource(desc_source), "__set__") + desc_var = VariableTracker.build(tx, descriptor, desc_source) + func_var = VariableTracker.build(tx, setter, func_source) + args = [desc_var, self, value] + return func_var.call_function(tx, args, {}) + # NOTE: else we assume the descriptor (if any) has a + # side-effect-free `__set__` as far as Dynamo tracing is concerned. + + # Emulate the standard setattr on instance dict. + tx.output.side_effects.store_attr(self, name, value) + return variables.ConstantVariable(None) + + def needs_slow_setattr(self): + return not is_standard_setattr( + inspect.getattr_static(self.value, "__setattr__", None) + ) and not isinstance(self.value, threading.local) + + def unpack_var_sequence(self, tx): + if ( + self.source + and self._maybe_get_baseclass_method("__iter__") is list.__iter__ + and self._maybe_get_baseclass_method("__len__") is list.__len__ + and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__ + ): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return [ + variables.LazyVariableTracker.create( + self.value[k], + source=GetItemSource(self.source, k), + ) + for k in range(len(self.value)) + ] + return super().unpack_var_sequence(tx) + + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + try: + variables.BuiltinVariable(iter).call_function(tx, [self], {}) + return True + except ObservedTypeError: + handle_observed_exception(tx) + return False + + def force_unpack_var_sequence(self, tx): + result = [] + iter_ = variables.BuiltinVariable(iter).call_function(tx, [self], {}) + + while True: + try: + r = iter_.next_variable(tx) + result.append(r) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + def next_variable(self, tx): + return self.call_method(tx, "__next__", [], {}) + + def is_supported_random(self): + try: + return self.value in self._supported_random_functions() + except TypeError: + # TypeError: unhashable type + return False + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + self.is_supported_random() + and all(k.is_python_constant() for k in args) + and all(v.is_python_constant() for v in kwargs.values()) + ): + return call_random_fn(tx, self.value, args, kwargs) + elif istype(self.value, types.MethodType): + func = self.value.__func__ + obj = self.value.__self__ + if ( + func is torch.utils._contextlib._DecoratorContextManager.clone + and variables.TorchCtxManagerClassVariable.is_matching_cls( + obj.__class__ + ) + and not (args or kwargs) + ): + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, args, kwargs) + + if ( + func is torch.autograd.grad_mode.inference_mode.clone + and obj.__class__ is torch.autograd.grad_mode.inference_mode + ): + # simulate the inference_mode.clone implementation + var = variables.ConstantVariable(obj.mode) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, [var], kwargs) + + if self.source is None: + unimplemented( + gb_type="attempted to call sourceless user-defined object as a method", + context=f"object={self.value}, function={func}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + f"Ensure the user-defined object {self.value} is constructed outside the compiled region.", + ], + ) + func_src = AttrSource(self.source, "__func__") + func_var = VariableTracker.build(tx, func, func_src) + obj_src = AttrSource(self.source, "__self__") + obj_var = VariableTracker.build(tx, obj, obj_src) + return func_var.call_function(tx, [obj_var] + args, kwargs) + elif callable(self.value): + if self.source: + source = AttrSource(self.cls_source, "__call__") + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return self.call_method(tx, "__call__", args, kwargs) + + return super().call_function(tx, args, kwargs) + + def _check_for_getattr(self): + return get_custom_getattr(self.value) + + def _is_c_defined_property(self, subobj): + if not isinstance(subobj, property): + return False + + # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose + # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check + + # If we have a PyCFunction, we make an assumption that there is no side effect. + return isinstance( + subobj.fget, types.BuiltinFunctionType + ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget) + + def _getattr_static(self, name): + subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) + + # In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local + # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup. + # NOTE we assume the following descriptors are side-effect-free as far + # as Dynamo tracing is concerned. + if not object_has_getattribute(self.value) and ( + subobj is NO_SUCH_SUBOBJ # e.g., threading.local + or inspect.ismemberdescriptor(subobj) # e.g., __slots__ + or inspect.isgetsetdescriptor(subobj) # e.g., __dict__ + or self._is_c_defined_property(subobj) + ): + # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't + # want to call getattr because it can be user-overridden. + subobj = type(self.value).__getattribute__(self.value, name) + elif object_has_getattribute(self.value) and subobj is NO_SUCH_SUBOBJ: + # If the object has an overridden getattribute method, Dynamo has + # already tried tracing it, and encountered an AttributeError. We + # call getattr_static only when the __getattribute__ tracing fails + # (check var_getattr impl). So, it is safe here to raise the + # AttributeError. + raise AttributeError + + return subobj + + def should_skip_descriptor_setter(self, attr_name): + # Check if `attr_name` corresponds to a descriptor. + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if setter: + # Skip if `__set__` was traceable (no need to redo the side effect). + if inspect.isfunction(setter): + return True + # For untraceable `__set__` we should still skip if the attribute + # was mutated via instance `__dict__`. + elif attr_name in self.attrs_directly_modifed_on_dict: + return True + return False + + def try_get_descritor_and_setter_py_func(self, attr_name): + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if inspect.isfunction(setter): + return (descriptor, setter) + return None + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def get_source_by_walking_mro(self, name): + assert self.cls_source is not None + + for idx, klass in enumerate(type(self.value).__mro__): + if name in klass.__dict__: + if idx != 0: + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, idx) + else: + klass_source = self.cls_source + dict_source = TypeDictSource(klass_source) + out_source = DictGetItemSource(dict_source, name) + + for absent_idx in range(1, idx): + # Insert a guard that the name is not present in the mro hierarchy + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, absent_idx) + dict_source = TypeDictSource(klass_source) + install_guard( + dict_source.make_guard( + functools.partial( + GuardBuilder.DICT_CONTAINS, key=name, invert=True + ) + ) + ) + # Insert a guard that the name is not present in the object __dict__ + if ( + self.source + and hasattr(self.value, "__dict__") + and name not in self.value.__dict__ + ): + install_guard( + self.source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name + ) + ) + ) + return out_source + + unimplemented( + gb_type="could not find name in object's mro", + context=f"name={name}, object type={type(self.value)}, mro={type(self.value).__mro__}", + explanation=f"Could not find name `{name}` in mro {type(self.value).__mro__}", + hints=[ + f"Ensure the name `{name}` is defined somewhere in {self.value}'s type hierarchy.", + *graph_break_hints.USER_ERROR, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + from . import ConstantVariable + + source = AttrSource(self.source, name) if self.source else None + + if object_has_getattribute(self.value): + getattribute_fn = inspect.getattr_static( + type(self.value), "__getattribute__" + ) + if self.source: + new_source = AttrSource(self.source, "__getattribute__") + try: + return variables.UserMethodVariable( + getattribute_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + except ObservedAttributeError: + # Pass through to __getattr__ if __getattribute__ fails + handle_observed_exception(tx) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception( + AttributeError, + tx, + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], + ) + return result + + if name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + + # TODO(anijain2305) - Investigate if we need specialization for more + # dunder attrs. inspect.getattr_static does not return correct value for + # them. + if name == "__class__": + cls_source = source + if cls_source is None: + cls_source = self.cls_source + options = {"source": cls_source} + return UserDefinedClassVariable(type(self.value), **options) + + try: + subobj = self._getattr_static(name) + except AttributeError: + subobj = NO_SUCH_SUBOBJ + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + if ( + getattr_fn is unpatched_nn_module_getattr + and isinstance(self, variables.UnspecializedNNModuleVariable) + # prevent against overwriting of params/buffers/submodules + and istype(self.value._parameters, dict) + and istype(self.value._buffers, dict) + and istype(self.value._modules, dict) + ): + # Manually trace out the nn module __getattr__ to avoid large compilation latency. + out = self.manually_trace_nn_module_getattr(tx, name) + else: + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + out = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + + if self.source and getattr_fn is torch.nn.Module.__getattr__: + if isinstance( + out, + ( + variables.UnspecializedNNModuleVariable, + variables.NNModuleVariable, + ), + ): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + elif getattr_fn is not None: + unimplemented( + gb_type="User-defined object with non-function __getattr__", + context=f"object={self.value}, name={name}, getattr_fn={getattr_fn}", + explanation=f"Found a non-function __getattr__ {getattr_fn} from a user-defined object {self.value} " + f" when attempting to getattr `{name}`", + hints=[ + "Ensure the object's __getattr__ is a function type.", + ], + ) + + from ..mutation_guard import unpatched_nn_module_init + + if subobj is torch.nn.Module.__init__: + subobj = unpatched_nn_module_init + + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + is_accessible_from_type_mro = ( + subobj_from_class is subobj + and self.cls_source is not None + and self.source is not None + and hasattr(self.value, "__dict__") + and name not in self.value.__dict__ + ) + + if isinstance(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = self.get_source_by_walking_mro(name) + # Get the getter function + source = AttrSource(source, "fget") + + fget_vt = VariableTracker.build(tx, subobj.fget, source=source) + return fget_vt.call_function(tx, [self], {}) + elif isinstance(subobj, _collections._tuplegetter): + # namedtuple fields are represented by _tuplegetter, and here we + # emulate its `__get__`, which is implemented in C. + _, (idx, _) = subobj.__reduce__() + # Don't go through the `__getitem__` method anymore, see + # https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690 + assert isinstance(self, UserDefinedTupleVariable) + return self._tuple_vt.items[idx] + elif isinstance(subobj, staticmethod): + # Safe because `staticmethod.__get__` basically won't trigger user + # code and just returns the underlying `__func__`: + # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100 + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a staticmethod object, so access the __func__ + # attribute to get to the actual function. + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") + func = subobj.__get__(self.value) + return VariableTracker.build(tx, func, source) + elif isinstance(subobj, classmethod): + source_fn = None + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a classmethod object, so access the __func__ + # attribute to get to the actual function. + source_fn = AttrSource(self.get_source_by_walking_mro(name), "__func__") + return variables.UserMethodVariable( + subobj.__func__, + self.var_getattr(tx, "__class__"), + source_fn=source_fn, + source=source, + ) + elif isinstance(subobj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static({}, "fromkeys") + func = subobj.__get__(self.value, None) + return VariableTracker.build(tx, func, source) + elif is_lru_cache_wrapped_function(subobj): + # getattr_static returns the lru_wrapped function, and we cannot + # extract the underlying method from the wrapped function. To handle + # it, manually create a wrapped user method vt. + return variables.WrapperUserMethodVariable( + subobj, "__wrapped__", self, source=source + ) + elif inspect.getattr_static( + type(subobj), "__get__", NO_SUCH_SUBOBJ + ) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor( + type(subobj).__get__ + ): + # Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285 + # + # Attribute has a __get__ method. Create a user defined object vt + # for the subobj, and then trace the __get__ method. + descriptor_source = None + descriptor_get_source = None + if self.cls_source: + # To access the method descriptor from the udf object w/o using + # inspect.getattr_static, we can look into the class mro + descriptor_source = self.get_source_by_walking_mro(name) + descriptor_get_source = AttrSource( + TypeSource(descriptor_source), "__get__" + ) + descriptor_var = VariableTracker.build(tx, subobj, descriptor_source) + else: + # Sourceless Builder does not support user defined objects + descriptor_var = UserDefinedObjectVariable(subobj) + + # The arguments of the __get__ function are (self, instance, owner) + # self - descriptor_var + # instance - instance of the class, represented by self here + # owner - class object + owner_var = UserDefinedClassVariable(type(self.value)) + return variables.UserMethodVariable( + subobj.__get__.__func__, descriptor_var, source=descriptor_get_source + ).call_function(tx, [self, owner_var], {}) + elif isinstance(subobj, types.FunctionType) or ( + isinstance(subobj, types.MethodType) + and isinstance(self.value, torch.nn.Module) + ): + # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. + # Static lookup can't tell us it's a method or function correctly, + # so we trigger dynamic lookup here to get the correct type. + dynamic_subobj = getattr(self.value, name) + + while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"): + subobj = subobj._torchdynamo_inline + dynamic_subobj = subobj + source = AttrSource(source, "_torchdynamo_inline") if source else None + + if isinstance(subobj, types.MethodType): + if dynamic_subobj.__self__ is not self.value: + if not isinstance(dynamic_subobj.__func__, types.FunctionType): + unimplemented( + gb_type="User-defined object method with non-function __func__", + context=f"object={self.value}, name={name}, method={dynamic_subobj}, " + f"method.__self__={dynamic_subobj.__self__}, method.__func__={dynamic_subobj.__func__}", + explanation=f"Method {dynamic_subobj} (name={name}) of user-defined object {self.value} has a " + f"__func__ ({dynamic_subobj.__func__}) that is not a function type.", + hints=[ + "Ensure that the method's __func__ is a function type.", + ], + ) + + # Use the __self__ attribute of the method to find the + # source of the new self object. + self_source = None + if source is not None: + self_source = AttrSource(source, "__self__") + object_vt = VariableTracker.build( + tx, dynamic_subobj.__self__, self_source + ) + + return variables.UserMethodVariable( + dynamic_subobj.__func__, object_vt + ) + func = subobj.__func__ + else: + assert isinstance(subobj, types.FunctionType) + func = subobj + + if inspect.ismethod(dynamic_subobj): + source_fn = None + if is_accessible_from_type_mro: + source_fn = self.get_source_by_walking_mro(name) + return variables.UserMethodVariable( + func, self, source_fn=source_fn, source=source + ) + elif inspect.isfunction(dynamic_subobj): + return VariableTracker.build(tx, func, source) + + if ( + # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to + # keep Dynamo behavior compatible with no inlining, as there will be some delay to turn on the flag in + # fbcode. + ( + torch._dynamo.config.inline_inbuilt_nn_modules + or isinstance(self, variables.FSDPManagedNNModuleVariable) + ) + and source + and isinstance(self, variables.UnspecializedNNModuleVariable) + # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export + # usecase for now. + and (not tx.output.export or torch._dynamo.config.install_free_tensors) + ): + # Recalculate source for params/buffers + if name in ("_buffers", "_parameters"): + source = UnspecializedParamBufferSource(self.source, name) + source = self._wrap_source(source) + + if subobj is not NO_SUCH_SUBOBJ: + if ( + is_wrapper_or_member_descriptor(subobj) + or torch._C._dynamo.utils.is_instancemethod(subobj) + or is_cython_function(subobj) + ): + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + if source: + if is_accessible_from_type_mro: + source = self.get_source_by_walking_mro(name) + + return variables.LazyVariableTracker.create(subobj, source) + else: + # Check if the subobj is accessible from the class itself. If the class source is known, we can create a + # sourceful variable tracker. + if self.cls_source is not None: + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + if subobj_from_class is subobj: + src_from_class = AttrSource(self.cls_source, name) + return variables.LazyVariableTracker.create( + subobj_from_class, src_from_class + ) + + return VariableTracker.build(tx, subobj) + + # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError. + raise_observed_exception( + AttributeError, + tx, + args=[f"'{type(self.value).__name__}' object has no attribute '{name}'"], + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + try: + var_vt = self.var_getattr(tx, name) + return variables.ConstantVariable.create( + not isinstance(var_vt, variables.DeletedVariable) + ) + except ObservedAttributeError: + handle_observed_exception(tx) + return variables.ConstantVariable.create(False) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + # default hash + return hash(self.value) + + def is_python_equal(self, other): + # id check + return self.value is other.value + + +class FrozenDataClassVariable(UserDefinedObjectVariable): + @staticmethod + def create(tx, value, source): + from dataclasses import fields + + assert is_frozen_dataclass(value) + + field_map = {} + for field in fields(value): + if hasattr(value, field.name): + field_map[field.name] = VariableTracker.build( + tx, + getattr(value, field.name), + source and AttrSource(source, field.name), + ) + + return FrozenDataClassVariable(value, fields=field_map, source=source) + + def __init__(self, value, fields=None, **kwargs) -> None: + super().__init__(value, **kwargs) + if fields is None: + fields = {} + self.fields = fields + + def as_python_constant(self): + # NOTE: this is an intentionally limited version of + # `as_python_constant` for `nonstrict_trace` implementation. + from dataclasses import fields + + import torch.utils._pytree as pytree + + if not istype( + self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode) + ): + # TODO loosen this restriction and fix `as_proxy`. + raise NotImplementedError( + "currently can't reconstruct arbitrary frozen dataclass instances" + ) + + # LeafSpec is deprecated, use treespec_leaf() instead + if istype(self.value, pytree.LeafSpec): + return pytree.treespec_leaf() + + args = [] + kwargs = {} + for field in fields(self.value): + if field.init: + data = self.fields[field.name].as_python_constant() + if getattr(field, "kw_only", False): + kwargs[field.name] = data + else: + args.append(data) + + # This is safe because we know the TreeSpec classes constructors don't + # have external side effects. + ctor = self.python_type() + return ctor(*args, **kwargs) + + def as_proxy(self): + from dataclasses import fields + + args = [] + kwargs = {} + for field in fields(self.value): + proxy = self.fields[field.name].as_proxy() + if hasattr(field, "kw_only") and field.kw_only: + kwargs[field.name] = proxy + else: + args.append(proxy) + + # TODO this isn't really safe, because + # 1. it could invoke a user defined `__post_init__`. + # 2. it could invoke a user defined `__init__` if the class _subclasses_ + # a frozen dataclass. + # Either of the above could end up mutating external state. + ctor = self.python_type() + return ctor(*args, **kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + from dataclasses import fields + + # Handle specific pytree classes + import torch.utils._pytree as pytree + + if isinstance(self.value, pytree.TreeSpec) and self.value.is_leaf(): + # Create a new LeafSpec instance by calling the constructor + codegen.add_push_null( + lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec") + ) + codegen.extend_output(create_call_function(0, False)) + return + + # For general frozen dataclasses, reconstruct by calling the constructor + # with the field values as arguments + dataclass_cls = self.python_type() + + if hasattr(dataclass_cls, "__post_init__"): + unimplemented( + gb_type="Frozen dataclass with __post_init__", + context=f"dataclass={dataclass_cls.__name__}", + explanation="Cannot reconstruct frozen dataclass with __post_init__ method, " + "as it may have side effects that would be incorrectly replayed.", + hints=[ + "Remove the __post_init__ method from the frozen dataclass.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Collect positional and keyword-only arguments + pos_args = [] + kw_args = [] + for field in fields(dataclass_cls): + if not field.init: + continue + field_vt = self.fields.get(field.name) + if field_vt is None: + unimplemented( + gb_type="Frozen dataclass with missing field", + context=f"dataclass={dataclass_cls.__name__}, field={field.name}", + explanation=f"Cannot reconstruct frozen dataclass: field '{field.name}' " + "was not tracked during tracing.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if getattr(field, "kw_only", False): + kw_args.append((field.name, field_vt)) + else: + pos_args.append(field_vt) + + # Load the dataclass constructor + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(dataclass_cls) + ) + ) + # Reconstruct all arguments + for arg_vt in pos_args: + codegen(arg_vt) + for _, arg_vt in kw_args: + codegen(arg_vt) + # Call the constructor + total_args = len(pos_args) + len(kw_args) + if kw_args: + kw_names = tuple(name for name, _ in kw_args) + codegen.extend_output( + codegen.create_call_function_kw(total_args, kw_names, push_null=False) + ) + else: + codegen.extend_output(create_call_function(total_args, False)) + + # NB: This is called during __init__ for a frozen dataclass + # use this to accumulate the most up-to-date field values + def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + self.fields[name.as_python_constant()] = value + return super().method_setattr_standard(tx, name, value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + def is_python_hashable(self): + # TODO - Check corner cases like eq=False, hash=False etc + return True + + def get_python_hash(self): + return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) + + def is_python_equal(self, other): + is_class_same = self.python_type() is other.python_type() + is_field_name_same = self.fields.keys() == other.fields.keys() + is_field_value_same = all( + value_a.is_python_equal(value_b) + for value_a, value_b in zip(self.fields.values(), other.fields.values()) + ) + return is_class_same and is_field_name_same and is_field_value_same + + +class SourcelessGraphModuleVariable(UserDefinedObjectVariable): + def __init__( + self, + value, + **kwargs, + ) -> None: + super().__init__(value, **kwargs) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + fn_variable = VariableTracker.build(tx, self.value.forward.__func__) + args = [self] + args + return tx.inline_user_function_return( + fn_variable, + args, + kwargs, + ) + + +class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable): + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.exc_vt = variables.ExceptionVariable(self.value_type, ()) + + @property + def fn(self): + return self.value_type + + def call_method(self, tx, name, args, kwargs): + if ( + name == "__init__" + and (method := self._maybe_get_baseclass_method(name)) + and inspect.ismethoddescriptor(method) + and len(kwargs) == 0 + ): + self.exc_vt.args = args + self.value.args = args + return variables.ConstantVariable(None) + elif ( + name == "__setattr__" + and len(args) == 2 + and args[0].is_constant_match( + "__cause__", "__context__", "__suppress_context__", "__traceback__" + ) + ): + self.exc_vt.call_setattr(tx, args[0], args[1]) + elif name == "with_traceback": + return self.exc_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + @property + def __context__(self): + return self.exc_vt.__context__ + + @property + def args(self): + return self.exc_vt.args + + def set_context(self, context: "variables.ExceptionVariable"): + return self.exc_vt.set_context(context) + + @property + def exc_type(self): + return self.exc_vt.exc_type + + +class KeyedJaggedTensorVariable(UserDefinedObjectVariable): + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torchrec.sparse.jagged_tensor") + return mod is not None and type(obj) is mod.KeyedJaggedTensor + + def __init__(self, value, **kwargs) -> None: + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + assert type(value) is KeyedJaggedTensor + super().__init__(value, **kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + if ( + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt + and self.source is not None + and name in ("_length_per_key", "_offset_per_key") + ): + with TracingContext.patch(force_unspec_int_unbacked_size_like=True): + return super().var_getattr(tx, name) + return super().var_getattr(tx, name) + + +class IntWrapperVariable(UserDefinedObjectVariable): + # Dummy class to check if the object is an IntWrapper, and turn it into a + # symint + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torch.export.dynamic_shapes") + return mod is not None and type(obj) is mod._IntWrapper + + +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + +class RemovableHandleVariable(VariableTracker): + REMOVED = -1 + + def __init__( + self, + mutation_type=None, + # index of the registration in the side_effects owned register_hook/handle list, used during removal. + idx=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mutation_type = mutation_type + self.idx = idx + + def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs): + if method_name == "remove": + if self.idx != self.REMOVED: + tx.output.side_effects.remove_hook(self.idx) + self.idx = self.REMOVED + return variables.ConstantVariable.create(None) + super().call_method(tx, method_name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + if self.idx == self.REMOVED: + # Hook has already been removed, return a dummy handle + codegen.add_push_null( + lambda: codegen.load_import_from( + "torch._dynamo.utils", "invalid_removeable_handle" + ) + ) + codegen.extend_output(create_call_function(0, False)) + return + # unreachable due to codegen.add_cache() when the hook is installed + super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass + + +class UserDefinedDictVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of dict/OrderedDict. + + Internally, it uses a ConstDictVariable to represent the dict part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, dict_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._dict_vt = dict_vt + if self._dict_vt is None: + assert self.source is None, ( + "dict_vt must be constructed by builder.py when source is present" + ) + self._dict_vt = variables.ConstDictVariable( + {}, type(value), mutation_type=ValueMutationNew() + ) + self._dict_methods = dict_methods + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._dict_methods: + # Dict subclasses can override __missing__ to provide fallback + # behavior instead of raising a KeyError. This is used, for example, + # by collections.Counter. + try: + return self._dict_vt.call_method(tx, name, args, kwargs) + except ObservedKeyError: + if ( + name == "__getitem__" + and issubclass(self.python_type(), dict) + and self._maybe_get_baseclass_method("__missing__") + ): + return self.call_method(tx, "__missing__", args, kwargs) + else: + raise + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + if type(self.value).__iter__ in ( + dict.__iter__, + collections.OrderedDict.__iter__, + ): + return self._dict_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._dict_vt) + + @property + def user_cls(self): + return self._dict_vt.user_cls + + @property + def items(self): + return self._dict_vt.items + + def install_dict_keys_match_guard(self): + return self._dict_vt.install_dict_keys_match_guard() + + def install_dict_contains_guard(self): + return self._dict_vt.install_dict_contains_guard() + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + + +class UserDefinedSetVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of set. + + Internally, it uses a SetVariable to represent the set part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, set_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._set_vt = set_vt + + python_type = set if isinstance(value, set) else frozenset + self._set_methods = set_methods if python_type is set else frozenset_methods + + if self._set_vt is None: + assert self.source is None, ( + "set_vt must be constructed by builder.py when source is present" + ) + if python_type is set: + # set is initialized later + self._set_vt = variables.SetVariable( + {}, mutation_type=ValueMutationNew() + ) + else: + init_args = kwargs.get("init_args", {}) + tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() + self._set_vt = variables.BuiltinVariable(python_type).call_function( + tx, init_args, {} + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._set_methods: + return self._set_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def as_python_constant(self): + return self._set_vt.as_python_constant() + + def unpack_var_sequence(self, tx): + if inspect.getattr_static(self.value, "__iter__") in ( + set.__iter__, + frozenset.__iter__, + ): + return self._set_vt.unpack_var_sequence(tx) + raise NotImplementedError + + @property + def set_items(self): + return self._set_vt.set_items + + @property + def items(self): + return self._set_vt.items + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._set_vt) + + def install_dict_keys_match_guard(self): + return self._set_vt.install_dict_keys_match_guard() + + def install_dict_contains_guard(self): + return self._set_vt.install_dict_contains_guard() + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._set_vt.is_python_hashable() + + def get_python_hash(self): + return self._set_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedSetVariable + ) and self._set_vt.is_python_equal(other._set_vt) + + +class UserDefinedListVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of lists. + + Internally, it uses a ListVariable to represent the list part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, list_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._list_vt = list_vt + if self._list_vt is None: + assert self.source is None, ( + "list_vt must be constructed by builder.py when source is present" + ) + self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew()) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._list_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in list_methods: + return self._list_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._list_vt is not None + if type(self.value).__iter__ is list.__iter__: + return self._list_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._list_vt) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + + +class UserDefinedTupleVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of tuple. + + Internally, it uses a TupleVariable to represent the tuple part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): + super().__init__(value, init_args=init_args, **kwargs) + self._tuple_vt = tuple_vt + if self._tuple_vt is None: + assert self.source is None, ( + "tuple_vt must be constructed by builder.py when source is present" + ) + # Emulate `tuple.__new__` + # https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710 + # + # TODO this duplicates the logic in `BuiltinVariable(tuple)` + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + elems = init_args[0].force_unpack_var_sequence(tx) + self._tuple_vt = variables.TupleVariable( + elems, mutation_type=ValueMutationNew() + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._tuple_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in tuple_methods: + return self._tuple_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._tuple_vt is not None + if type(self.value).__iter__ is tuple.__iter__: + return self._tuple_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._tuple_vt.is_python_hashable() + + def get_python_hash(self): + return self._tuple_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedTupleVariable + ) and self._tuple_vt.is_python_equal(other._tuple_vt) + + +class MutableMappingVariable(UserDefinedObjectVariable): + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.generic_dict_vt = variables.ConstDictVariable({}) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # A common pattern in the init code of MutableMapping objects is to + # update the __dict__ attribute. To prevent graph break, we directly + # return a ConstDictVariable for the __dict__attr. + # + # However, users can try to add a new attribute to the class using the + # __dict__ attribute. To catch this, we save the ConstDictVariable for + # the __dict__ and then lookup into this vt for each attr lookup. + if name == "get" and type(self.value).get in ( + collections.abc.Mapping.get, + dict.get, + ): + return variables.UserMethodVariable(polyfills.mapping_get, self) + elif name == "__dict__" and self.source: + self.generic_dict_vt = variables.LazyVariableTracker.create( + self.value.__dict__, AttrSource(self.source, "__dict__") + ) + return self.generic_dict_vt + elif out := self.generic_dict_vt.maybe_getitem_const( + variables.ConstantVariable(name) + ): + return out + else: + return super().var_getattr(tx, name) + + +class RandomVariable(UserDefinedObjectVariable): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b629d43ef3b5d9734cf2fc6bf1502026d30c0c30 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py @@ -0,0 +1,190 @@ +import json +import logging +from typing import Any + +from torch._logging import trace_structured +from torch.fx import Graph, Node + + +log: logging.Logger = logging.getLogger(__name__) + + +def create_joint_graph_node_information( + joint_graph: Graph, + recomputable_node_info: dict[str, int], +) -> dict[str, Any]: + joint_graph_node_information: dict[str, Any] = {} + + for i, joint_graph_node in enumerate(joint_graph.nodes): + is_recomputable_candidate: bool = ( + joint_graph_node.name in recomputable_node_info + ) + tensor_meta = joint_graph_node.meta.get("tensor_meta") + shape = getattr(tensor_meta, "shape", []) if tensor_meta else [] + + node_info: dict[str, Any] = { + "index": i, + "name": joint_graph_node.name, + "is_recomputable_candidate": is_recomputable_candidate, + "target": str(joint_graph_node.target), + "shape": str(shape), + "input_arguments": [inp.name for inp in joint_graph_node.all_input_nodes], + "stack_trace": joint_graph_node.meta.get("stack_trace", ""), + } + + if is_recomputable_candidate: + idx: int = recomputable_node_info[joint_graph_node.name] + node_info["recomputable_candidate_info"] = { + "recomputable_node_idx": idx, + } + + joint_graph_node_information[joint_graph_node.name] = node_info + + return joint_graph_node_information + + +def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]: + joint_graph_edges: list[tuple[str, str]] = [ + (inp.name, node.name) + for node in joint_graph.nodes + for inp in node.all_input_nodes + ] + return joint_graph_edges + + +def create_activation_checkpointing_logging_structure_payload( + joint_graph: Graph, + joint_graph_node_information: dict[str, Any], + joint_graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[Node], + expected_runtime: float, + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + memories_banned_nodes: list[int], + normalized_memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> dict[str, Any]: + """ + Creates a structured payload for logging activation checkpointing information. + + Args: + joint_graph: The computational graph representing operations. + joint_graph_node_information: Dictionary containing information about nodes in the joint graph. + joint_graph_edges: List of edges in the joint graph represented as tuples of node names. + all_recomputable_banned_nodes: List of nodes that are banned from recomputation. + expected_runtime: Expected runtime of the computation. + saved_node_idxs: Indices of nodes that are saved (not recomputed). + recomputable_node_idxs: Indices of nodes that can be recomputed. + memories_banned_nodes: Memory usage values (in absolute units) for banned nodes. + normalized_memories_banned_nodes: Normalized memory usage values for banned nodes, + used as input to the knapsack algorithm. + runtimes_banned_nodes: Runtime values for banned nodes, used as input to the + knapsack algorithm. + min_cut_saved_values: List of nodes saved by the min-cut algorithm. + + Returns: + A dictionary containing structured logging information for activation checkpointing. + """ + activation_checkpointing_logging_structure_payload: dict[str, Any] = { + "Joint Graph Size": len(joint_graph.nodes), + "Joint Graph Edges": { + "Total": len(joint_graph_edges), + "Edges": joint_graph_edges, + }, + "Joint Graph Node Information": joint_graph_node_information, + "Recomputable Banned Nodes Order": [ + node.name for node in all_recomputable_banned_nodes + ], + "Expected Runtime": expected_runtime, + "Knapsack Saved Nodes": saved_node_idxs, + "Knapsack Recomputed Nodes": recomputable_node_idxs, + "Absolute Memories": memories_banned_nodes, + "Knapsack Input Memories": normalized_memories_banned_nodes, + "Knapsack Input Runtimes": runtimes_banned_nodes, + "Min Cut Solution Saved Values": [node.name for node in min_cut_saved_values], + } + return activation_checkpointing_logging_structure_payload + + +def create_structured_trace_for_min_cut_info( + joint_graph: Graph, + all_recomputable_banned_nodes: list[Node], + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + expected_runtime: float, + memories_banned_nodes: list[int], + normalized_memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> None: + """ + Creates a structured trace for minimum cut information in the graph. + + Args: + joint_graph: The computational graph representation. + all_recomputable_banned_nodes: List of nodes that can be recomputed. + saved_node_idxs: Indices of nodes that are saved in memory. + recomputable_node_idxs: Indices of nodes that are recomputed. + expected_runtime: Expected runtime for the computation. + memories_banned_nodes: Memory requirements for each banned node in bytes. + normalized_memories_banned_nodes: Normalized memory requirements for each banned node + (typically scaled between 0 and 1 for relative comparison). + runtimes_banned_nodes: Runtime costs associated with each banned node. + min_cut_saved_values: Nodes that are saved as part of the minimum cut solution. + """ + # Create a dictionary to store recomputable node information + recomputable_node_info: dict[str, int] = { + node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes) + } + + # Create joint graph node information + joint_graph_node_information = create_joint_graph_node_information( + joint_graph, recomputable_node_info + ) + + # Update node information with recomputable candidate details + for node_name, node_info in joint_graph_node_information.items(): + if node_info["is_recomputable_candidate"]: + idx = recomputable_node_info[node_name] + node_info["recomputable_candidate_info"]["memory"] = memories_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["runtime"] = runtimes_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["is_saved"] = ( + idx in saved_node_idxs + ) + node_info["recomputable_candidate_info"]["is_recomputed"] = ( + idx in recomputable_node_idxs + ) + + # Create joint graph edges + joint_graph_edges = create_joint_graph_edges(joint_graph) + + # Create activation checkpointing logging structure payload + activation_checkpointing_logging_structure_payload = ( + create_activation_checkpointing_logging_structure_payload( + joint_graph=joint_graph, + joint_graph_node_information=joint_graph_node_information, + joint_graph_edges=joint_graph_edges, + all_recomputable_banned_nodes=all_recomputable_banned_nodes, + expected_runtime=expected_runtime, + saved_node_idxs=saved_node_idxs, + recomputable_node_idxs=recomputable_node_idxs, + memories_banned_nodes=memories_banned_nodes, + normalized_memories_banned_nodes=normalized_memories_banned_nodes, + runtimes_banned_nodes=runtimes_banned_nodes, + min_cut_saved_values=min_cut_saved_values, + ) + ) + + # Create structured trace + trace_structured( + "artifact", + metadata_fn=lambda: {"name": "min_cut_information", "encoding": "json"}, + payload_fn=lambda: json.dumps( + activation_checkpointing_logging_structure_payload + ), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5da58fdd63303bebddd2439f7b6607b45377d5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py @@ -0,0 +1,319 @@ +from typing import Any, Optional + +import networkx as nx + +from torch.fx import Graph, Node + + +class GraphInfoProvider: + """ + This class provides information about the graph, such as the nodes, edges, and their runtime and memory requirements. + It also provides methods to create graphs from the information provided. + """ + + __RECOMPUTABLE_NODE_ONLY_GRAPH = "recomputable_node_only_graph" + __RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT = ( + "recomputable_node_only_graph_with_larger_graph_context" + ) + __FULL_NX_JOINT_GRAPH = "full_nx_joint_graph" + __SIMPLIFIED_FX_JOINT_GRAPH = "fx_joint_graph" + + def __init__( + self, + graph_nodes_in_order: list[str], + graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[str], + all_node_runtimes: Optional[dict[str, float]] = None, + all_node_memories: Optional[dict[str, float]] = None, + recorded_knapsack_input_memories: Optional[list[float]] = None, + recorded_knapsack_input_runtimes: Optional[list[float]] = None, + joint_graph: Optional[Graph] = None, + ): + self.graph_nodes_in_order = graph_nodes_in_order + self.graph_edges = graph_edges + self.all_node_runtimes: dict[str, float] = dict() + if all_node_runtimes is None: + if recorded_knapsack_input_runtimes is None: + raise ValueError( + "Either all_node_runtimes or recorded_knapsack_input_runtimes must be provided." + ) + self.all_node_runtimes = { + node: recorded_knapsack_input_runtimes[i] + for i, node in enumerate(all_recomputable_banned_nodes) + } + else: + self.all_node_runtimes.update(all_node_runtimes) + self.all_node_memories: dict[str, float] = dict() + if all_node_memories is None: + if recorded_knapsack_input_memories is None: + raise ValueError( + "Either all_node_memories or recorded_knapsack_input_memories must be provided." + ) + self.all_node_memories = { + node: recorded_knapsack_input_memories[i] + for i, node in enumerate(all_recomputable_banned_nodes) + } + else: + self.all_node_memories.update(all_node_memories) + self.all_recomputable_banned_nodes = all_recomputable_banned_nodes + self.all_recomputable_banned_nodes_set = set(all_recomputable_banned_nodes) + self.recorded_knapsack_input_memories = recorded_knapsack_input_memories + self.recorded_knapsack_input_runtimes = recorded_knapsack_input_runtimes + self._lazily_initialized_graphs: dict[str, Any] = { + self.__RECOMPUTABLE_NODE_ONLY_GRAPH: None, + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT: None, + self.__FULL_NX_JOINT_GRAPH: None, + self.__SIMPLIFIED_FX_JOINT_GRAPH: None, + } + + @classmethod + def inialize_from_graph( + cls, + joint_graph: Graph, + all_recomputable_banned_nodes: list[Node], + recorded_knapsack_input_memories: list[float], + recorded_knapsack_input_runtimes: list[float], + ) -> "GraphInfoProvider": + """ + Enables initialization from a joint graph. + """ + graph_nodes_in_order = [node.name for node in joint_graph.nodes] + graph_edges = [ + (node.name, user.name) for node in joint_graph.nodes for user in node.users + ] + all_recomputable_banned_node_names = [ + node.name for node in all_recomputable_banned_nodes + ] + return cls( + graph_nodes_in_order=graph_nodes_in_order, + graph_edges=graph_edges, + all_recomputable_banned_nodes=all_recomputable_banned_node_names, + recorded_knapsack_input_memories=recorded_knapsack_input_memories, + recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes, + joint_graph=joint_graph, + ) + + @property + def recomputable_node_only_graph(self) -> nx.DiGraph: + if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None: + self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] = ( + self._create_recomputable_node_only_graph() + ) + return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] + + @property + def recomputable_node_only_graph_with_larger_graph_context(self) -> nx.DiGraph: + if ( + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] + is None + ): + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] = self._create_recomputable_node_only_graph_with_larger_graph_context() + return self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] + + @property + def full_joint_nx_graph(self) -> nx.DiGraph: + if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None: + self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] = ( + self._create_full_joint_graph() + ) + return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] + + @property + def simplified_fx_joint_graph(self) -> Graph: + if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None: + self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] = ( + self._recreate_psuedo_joint_graph() + ) + return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] + + def get_non_ac_peak_memory(self) -> float: + return sum( + self.all_node_memories[node_name] + for node_name in self.all_recomputable_banned_nodes_set + ) + + def get_theoretical_max_runtime(self) -> float: + return sum( + self.all_node_runtimes[node_name] + for node_name in self.all_recomputable_banned_nodes_set + ) + + def get_knapsack_memory_input(self) -> list[float]: + return ( + self.recorded_knapsack_input_memories + if self.recorded_knapsack_input_memories + else [ + self.all_node_memories[node_name] + for node_name in self.all_recomputable_banned_nodes + ] + ) + + def get_knapsack_runtime_input(self) -> list[float]: + return ( + self.recorded_knapsack_input_runtimes + if self.recorded_knapsack_input_runtimes + else [ + self.all_node_runtimes[node_name] + for node_name in self.all_recomputable_banned_nodes + ] + ) + + def _create_recomputable_node_only_graph(self) -> nx.DiGraph: + graph = nx.DiGraph() + for recomputable_node in self.all_recomputable_banned_nodes: + graph.add_node(recomputable_node) + + for a, b in self.graph_edges: + if ( + a in self.all_recomputable_banned_nodes_set + and b in self.all_recomputable_banned_nodes_set + ): + graph.add_edge(a, b) + return graph + + def _create_recomputable_node_only_graph_with_larger_graph_context( + self, + ) -> nx.DiGraph: + # Create a dictionary to store the reachable nodes for each node + all_recomputable_banned_nodes_set = set(self.all_recomputable_banned_nodes) + + reachable_nodes = {} + for node in all_recomputable_banned_nodes_set: + # Use BFS to find all reachable nodes + predecessors = dict(nx.bfs_predecessors(self.full_joint_nx_graph, node)) + reachable_recomputable_nodes = set(predecessors.keys()).intersection( + all_recomputable_banned_nodes_set + ) + reachable_nodes[node] = reachable_recomputable_nodes + # Create the candidate graph + candidate_graph = nx.DiGraph() + candidate_graph.add_nodes_from(all_recomputable_banned_nodes_set) + for node1 in all_recomputable_banned_nodes_set: + for node2 in reachable_nodes[node1]: + # Check if there is an overlapping path + overlapping_path = False + for intermediate_node in reachable_nodes[node1]: + if ( + intermediate_node != node2 + and node2 in reachable_nodes[intermediate_node] + ): + overlapping_path = True + break + if not overlapping_path: + candidate_graph.add_edge(node1, node2) + return candidate_graph + + def _create_full_joint_graph(self) -> nx.DiGraph: + graph = nx.DiGraph() + for node in self.graph_nodes_in_order: + if node == "output": + continue + graph.add_node(node) + + for a, b in self.graph_edges: + if a == "output" or b == "output": + continue + graph.add_edge(a, b) + return graph + + def _recreate_psuedo_joint_graph(self) -> Graph: + # Create a dictionary to store the dependencies of each node + node_dependencies: dict[str, list[str]] = { + node: [] for node in self.graph_nodes_in_order + } + for a, b in self.graph_edges: + if a not in node_dependencies or b not in node_dependencies: + raise ValueError(f"Edge ({a}, {b}) references a non-existent node.") + node_dependencies[b].append(a) + + joint_graph = Graph() + # Create nodes in the graph + nodes: dict[str, Node] = {} + for node_name in self.graph_nodes_in_order: + input_nodes = [nodes[dep] for dep in node_dependencies[node_name]] + if input_nodes: + node = joint_graph.call_function(lambda *x: x, tuple(input_nodes)) + node.name = node_name + else: + node = joint_graph.placeholder(node_name) + nodes[node_name] = node + return joint_graph + + def _visualize_recomputable_candidate_graph_with_larger_context( + self, + layout_k: float = 0.5, + layout_iterations: int = 30, + ) -> None: + """ + Visualize the recomputable candidate graph with larger context. + """ + from matplotlib import cm, colors as mcolors, pyplot as plt + + pos = nx.spring_layout( + self.recomputable_node_only_graph_with_larger_graph_context, + k=layout_k, + iterations=layout_iterations, + ) + # pos = nx.spectral_layout(graph_with_indirect_edges) + plt.figure(figsize=(20, 15)) + + # Create a dictionary for node labels using the index + labels = { + node: self.recomputable_node_only_graph_with_larger_graph_context.nodes[ + node + ].get("index", node) + for node in self.recomputable_node_only_graph_with_larger_graph_context.nodes + } + + # Extract memory values and normalize them + norm = mcolors.Normalize( + vmin=min(self.get_knapsack_memory_input()), + vmax=max(self.get_knapsack_memory_input()), + ) + cmap = cm.viridis # type: ignore[attr-defined] + + # Assign colors based on memory + node_colors = [ + cmap( + norm( + float( + self.recomputable_node_only_graph_with_larger_graph_context.nodes[ + node + ]["memory"] + ) + ) + ) + for node in self.recomputable_node_only_graph_with_larger_graph_context.nodes + ] + + # Draw the graph with parsed nodes only + nx.draw_networkx_nodes( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + node_color=node_colors, + node_size=300, + label="Parsed Nodes", + ) + nx.draw_networkx_edges( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + arrows=True, + arrowsize=10, + ) + nx.draw_networkx_labels( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + labels=labels, + font_size=8, + font_weight="bold", + ) + + plt.title("Memory Colour Coded Dependency Graph for Recomputable Nodes") + plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), label="Memory") + plt.show() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f0a124c64c1ec7ec6651aa79ff62ebec557949 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py @@ -0,0 +1,267 @@ +import torch + + +def greedy_knapsack( + memory: list[float], runtimes: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: list[float], runtimes: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: list[float], runtime: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [round(m * S) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtime, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = round(max_memory * S) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] + j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items + + +def dp_knapsack_sliding_hirschberg( + memory: list[float], runtime: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # q_ prefix stands for quantized + q_memory = [int(round(m * S)) for m in memory] + runtimes = [float(v) for v in runtime] + + q_max_memory = int(round(max_memory * S)) + + q_memory_length = len(q_memory) + if q_memory_length == 0: + return 0.0, [], [] + + item_indices = list(range(q_memory_length)) + dp_profile_size = q_max_memory + 1 + + # Current DP profile (row) + dp_profile = torch.zeros(dp_profile_size, dtype=torch.float32, device="cpu") + # Store a candidate for next dp_profile - current dp row + item + candidate_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + left_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + right_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + + saved_items: list[int] = [] + recomputable_items: list[int] = [] + + # Explicit stack to optimize memory and avoid recursion + # Stack stores segments as (start index, end index, capacity for segment) + stack: list[tuple[int, int, int]] = [(0, q_memory_length, q_max_memory)] + + # LIFO + while stack: + start, end, capacity = stack.pop() + length = end - start + if length == 0: + continue + + # Leaf + if length == 1: + index = item_indices[start] + memory_item = q_memory[index] + runtime_item = runtimes[index] + if memory_item <= capacity and runtime_item > 0.0: + saved_items.append(index) + else: + recomputable_items.append(index) + continue + + # Split the segment into two halves + middle = start + (length // 2) + left_start, left_end = middle, end + right_start, right_end = start, middle + + # Assign items to both halves + left_items = item_indices[left_start:left_end] + right_items = item_indices[right_start:right_end] + + # Working only on items allowed by segment's capacity + capacity = capacity + 1 + dp_view = dp_profile[:capacity] + candidate_view = candidate_profile[:capacity] + left_dp_local = left_profile[:capacity] + right_dp_local = right_profile[:capacity] + + # Left part + dp_view.zero_() + for index in left_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + # Weight is 0, so add it to all capacities; a "free lunch", essentially + dp_view.add_(runtime_item) + continue + + # If item is too heavy, we skip it + if memory_item >= capacity: + continue + + # Add the current item so we can then pick the highest value + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + # Take the highest - either previous (without current) or with current + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the left profile + left_dp_local.copy_(dp_view) + + # Right part + dp_view.zero_() + for index in right_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + dp_view.add_(runtime_item) + continue + + if memory_item >= capacity: + continue + + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the reversed right profile + right_dp_local.copy_(dp_view.flip(-1)) + + # In-place compute item-wise sum of left and right to pick the split point where the sum is highest + left_dp_local.add_(right_dp_local) + + # Pick the index of highest value of a pair, which we then use as a split point + best_split = int(torch.argmax(left_dp_local).item()) + + left_capacity = best_split + right_capacity = capacity - best_split + + # Clamp (might be removed if we're 100% sure that there is no edge case that will mess up the indices math) + if left_capacity < 0: + left_capacity = 0 + if right_capacity < 0: + right_capacity = 0 + if left_capacity > q_max_memory: + left_capacity = q_max_memory + if right_capacity > q_max_memory: + right_capacity = q_max_memory + + # Push right then left, so left is processed next + stack.append((right_start, right_end, right_capacity)) + stack.append((left_start, left_end, left_capacity)) + + saved_items = sorted(saved_items) + recomputable_items = sorted(recomputable_items) + + max_runtime = sum(runtime[i] for i in saved_items) + recomputable_items.reverse() + return max_runtime, saved_items, recomputable_items diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1a3db275d2dc548e0edbebb632913d8fed01ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py @@ -0,0 +1,273 @@ +import operator +from collections import deque +from collections.abc import Callable + +import networkx as nx + +from torch._functorch._activation_checkpointing.graph_info_provider import ( + GraphInfoProvider, +) + + +class KnapsackEvaluator: + """ + This class evaluates the theoretical runtime and peak memory usage of a given checkpointing strategy. + It takes in a graph and a list of nodes that are saved and recomputed, and then simulates the + backward pass to calculate the peak memory usage. + """ + + def __init__( + self, + graph_info_provider: GraphInfoProvider, + ) -> None: + self._graph_info_provider = graph_info_provider + + def _get_backward_memory_from_topologically_sorted_graph( + self, + node_graph: nx.DiGraph, + node_memories: dict[str, float], + saved_nodes_set: set[str], + peak_memory_after_forward_pass: float, + ) -> list[tuple[float, str]]: + """ + Simulates the backward pass and keeps track of the peak memory usage. + + High Level Steps: + 1. Set Initial Peak/Current Memory + Allows you to set the peak memory after the forward pass, but typically this is + the sum of the estimated memory of the saved nodes. + 2. Perform a reverse topological sort of the node_graph. + If full graph is defined then will sort the full graph and only process the subset + of nodes in the node_graph. + 3. Iterate through the sorted graph nodes. + If the node is saved then just drop it's memory from current memory. + If the node is not saved then add it's memory to current memory and then traverse it's + predecessors to simulate recomuptation chain. Will check if new peak memory after all + predecessors are processed. + + Args: + node_graph (nx.DiGraph): A directed graph representing the recomputable forward nodes. + saved_nodes_set (Set[str]): A set of node names that are saved. + peak_memory_after_forward_pass (float): The peak memory usage after the forward pass. + """ + current_memory = [ + (peak_memory_after_forward_pass, "Initial Peak/Current Memory") + ] + already_computed = set() + sorted_nodes = list(reversed(list(nx.topological_sort(node_graph)))) + dependencies_computed = set() + + for node in sorted_nodes: + if node in saved_nodes_set or node in already_computed: + current_memory.append( + ( + current_memory[-1][0] - node_memories[node], + f"Dropping Node(already saved): {node}", + ) + ) + continue + + already_computed.add(node) + current_memory.append( + ( + current_memory[-1][0] + node_memories[node], + f"Recomputing Node: {node}", + ) + ) + # Create a queue of dependencies required for recomputation + predecessor_queue = deque( + [ + dependency + for dependency, v in node_graph.in_edges(node) + if dependency not in already_computed + ] + ) + while predecessor_queue: + dep = predecessor_queue.popleft() + already_computed.add(dep) + dependencies_computed.add(dep) + current_memory.append( + ( + current_memory[-1][0] + node_memories[dep], + f"Recomputing Predecessor of {node}: {dep}", + ) + ) + # Add predecessors of the predecessor to the queue if they haven't been recomputed yet + for dependency_of_dependency, _ in node_graph.in_edges(dep): + if ( + dependency_of_dependency in already_computed + or dependency_of_dependency in saved_nodes_set + or dependency_of_dependency in predecessor_queue + ): + continue + predecessor_queue.append(dependency_of_dependency) + dependencies_computed.clear() + current_memory.append( + (current_memory[-1][0] - node_memories[node], f"Dropping Node: {node}") + ) + return current_memory + + def _validate_all_indexes_accounted_for_in_provided_output( + self, saved_nodes_idxs: list[int], recomputable_node_idxs: list[int] + ) -> None: + """ + Validate that all indexes are accounted for in the provided output. + This function checks that the union of saved nodes and recomputable nodes + covers all candidate nodes without any overlaps. + """ + recomputable_node_idxs_set = set(recomputable_node_idxs) + saved_nodes_idxs_set = set(saved_nodes_idxs) + all_candidate_nodes_idxs = set( + range(len(self._graph_info_provider.all_recomputable_banned_nodes)) + ) + # Check that there are no overlaps between saved nodes and recomputable nodes + assert ( + len(recomputable_node_idxs_set.intersection(saved_nodes_idxs_set)) == 0 + ), "Saved nodes and recomputable nodes cannot have any overlaps" + # Check that all candidate nodes are accounted for + assert ( + recomputable_node_idxs_set.union(saved_nodes_idxs_set) + == all_candidate_nodes_idxs + ), "All candidate nodes must be accounted for in the provided output" + + def evaluate_knapsack_output( + self, + saved_nodes_idxs: list[int], + recomputable_node_idxs: list[int], + account_for_backward_pass: bool = False, + ) -> dict[str, float]: + """ + Evaluate the theoretical runtime and peak memory usage of a given checkpointing strategy. + Args: + - saved_nodes_idxs (List[int]): The indices of nodes that are saved. + - recomputable_node_idxs (List[int]): The indices of nodes that need to be recomputed. + """ + self._validate_all_indexes_accounted_for_in_provided_output( + saved_nodes_idxs, recomputable_node_idxs + ) + recomputation_runtime = sum( + self._graph_info_provider.all_node_runtimes[ + self._graph_info_provider.all_recomputable_banned_nodes[node] + ] + for node in recomputable_node_idxs + ) + if account_for_backward_pass: + memory_list = self._get_backward_memory_from_topologically_sorted_graph( + node_graph=self._graph_info_provider.recomputable_node_only_graph_with_larger_graph_context, + saved_nodes_set={ + self._graph_info_provider.all_recomputable_banned_nodes[i] + for i in saved_nodes_idxs + }, + node_memories=self._graph_info_provider.all_node_memories, + peak_memory_after_forward_pass=sum( + self._graph_info_provider.all_node_memories[ + self._graph_info_provider.all_recomputable_banned_nodes[i] + ] + for i in saved_nodes_idxs + ), + ) + peak_memory = max(memory_list, key=operator.itemgetter(0))[0] + else: + peak_memory = sum( + self._graph_info_provider.all_node_memories[ + self._graph_info_provider.all_recomputable_banned_nodes[node] + ] + for node in saved_nodes_idxs + ) + return { + "peak_memory": peak_memory, + "recomputation_runtime": recomputation_runtime, + "non_ac_peak_memory": self._graph_info_provider.get_non_ac_peak_memory(), + "theoretical_max_runtime": self._graph_info_provider.get_theoretical_max_runtime(), + "percentage_of_theoretical_peak_memory": peak_memory + / self._graph_info_provider.get_non_ac_peak_memory(), + "percentage_of_theoretical_peak_runtime": recomputation_runtime + / self._graph_info_provider.get_theoretical_max_runtime(), + } + + def evaluate_distribution_of_results_for_knapsack_algo( + self, + knapsack_algo: Callable[ + [list[float], list[float], float], tuple[float, list[int], list[int]] + ], + memory_budget_values: list[float], + ) -> list[dict[str, float]]: + """ + Evaluates the distribution of results for a given knapsack algorithm. + Args: + knapsack_algo (Callable): The knapsack algorithm to use for evaluation. + memory_budget_values (List[float]): A list of memory budgets to evaluate. + """ + results = list() + for memory_budget in memory_budget_values: + _, saved_nodes, recomputed_nodes = knapsack_algo( + self._graph_info_provider.get_knapsack_memory_input(), + self._graph_info_provider.get_knapsack_runtime_input(), + memory_budget, + ) + result = self.evaluate_knapsack_output( + saved_nodes_idxs=saved_nodes, + recomputable_node_idxs=recomputed_nodes, + ) + result["memory_budget"] = memory_budget + results.append(result) + return results + + def get_knee_point_memory_budget( + self, + knapsack_algo: Callable[ + [list[float], list[float], float], tuple[float, list[int], list[int]] + ], + max_mem_budget: float = 0.1, + min_mem_budget: float = 0.001, + iterations: int = 100, + ) -> float: + """ + Finds the memory budget at the knee point in the Pareto frontier. + + The knee point is defined as the point where the trade-off between + runtime and memory usage is optimal. + + Args: + knapsack_algo (callable): Knapsack algorithm to use for evaluation. + max_mem_budget (float, optional): Maximum memory budget. Defaults to 0.1. + min_mem_budget (float, optional): Minimum memory budget. Defaults to 0.001. + iterations (int, optional): Number of memory budgets to evaluate. Defaults to 100. + + Returns: + float: Memory budget at the knee point. + """ + results = self.evaluate_distribution_of_results_for_knapsack_algo( + knapsack_algo=knapsack_algo, + memory_budget_values=[ + min_mem_budget + + i * (max_mem_budget - min_mem_budget) / (iterations - 1) + for i in range(iterations) + ], + ) + runtime_values = [ + result["percentage_of_theoretical_peak_runtime"] for result in results + ] + memory_values = [ + result["percentage_of_theoretical_peak_memory"] for result in results + ] + runtime_range = max(runtime_values) - min(runtime_values) + memory_range = max(memory_values) - min(memory_values) + if runtime_range == 0 or memory_range == 0: + return max_mem_budget + + # Normalize values + runtime_min = min(runtime_values) + memory_min = min(memory_values) + runtime_norm = [ + (value - runtime_min) / runtime_range for value in runtime_values + ] + memory_norm = [(value - memory_min) / memory_range for value in memory_values] + # Calculate Euclidean distance + distances = [ + (runtime_norm[i] ** 2 + memory_norm[i] ** 2) ** 0.5 + for i in range(len(runtime_norm)) + ] + # Find the knee point(shortest distance from the origin) + knee_index = distances.index(min(distances)) + return results[knee_index]["memory_budget"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f975bf0b5d111b0188d6ebc56e334eccb2a164fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py @@ -0,0 +1,134 @@ +""" +AC rematerialize pass: Duplicates checkpointed nodes for backward, then DCE removes unused forward versions. +""" + +import warnings + +import torch +import torch.fx as fx +from torch._functorch import config +from torch._functorch.compile_utils import raise_getitems +from torch._functorch.partitioners import ( + cleanup_recompute_tags, + force_save_bw_mutation_src, + force_save_collectives, + has_recomputable_ops, + has_recomputable_rng_ops, + is_not_collective, + must_recompute, +) + + +def is_impure_node_for_dce(node): + # Check for special collectives that should be treated as pure + if not is_not_collective(node): + # It's a collective (wait_tensor, all_gather_into_tensor, etc.) + # Treat as pure - can be eliminated if unused + return False + + # For everything else, fall back to the DEFAULT logic + # This is what eliminate_dead_code() calls when is_impure_node=None + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + return node.is_impure(impure_random) + + +def _is_backward_node(node: fx.Node) -> bool: + """Check if node is in backward region via annotation""" + return node.meta.get("custom", {}).get("remat_pass_tag", None) == "is_backward" + + +def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModule: + """ + Duplicate checkpointed nodes for backward use. DCE removes unused forward versions. We assume that + you already annotated your backward region with fx.traceback.annotate({"remat_pass_tag": "is_backward"}) + which helps us identify the backward region. + """ + if not has_recomputable_ops(gm): + return gm + + # Find backward boundary and build ordering + bwd_start: int | None = None + order = {} + for idx, node in enumerate(gm.graph.nodes): + order[node] = idx + if _is_backward_node(node) and bwd_start is None: + bwd_start = idx + + if bwd_start is None: + warnings.warn( + "remat_using_tags_for_fwd_loss_bwd_graph: Graph has recomputable ops but no backward region. " + "This may indicate a forward-only graph (e.g., from nested compilation) or missing backward annotations. " + "Returning graph unchanged." + ) + return gm + + if has_recomputable_rng_ops(gm): + raise RuntimeError( + "Activation checkpoint rematerializing in `forward-loss-backward` graph does not support RNG ops " + "in checkpointed regions. Please move RNG operations outside " + "of checkpoint regions, or use joint graph mode (where partitioner handles RNG)." + ) + + # Use partitioner pass to normalize AC node tags. + gm = cleanup_recompute_tags(gm, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(gm) + + force_save_bw_mutation_src(gm) + + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + recomputed_nodes: dict[fx.Node, fx.Node] = {} + + # Insert forward nodes + for node in list(gm.graph.nodes)[:bwd_start]: + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + def remat_input(x): + # fx.Node can have args that are primitive types (e.g. int, float, bool) + if not isinstance(x, fx.Node): + return x + return recomputed_nodes.get(x, env[x]) + + def gather_checkpointed_deps(node: fx.Node, visited: set) -> None: + if node in visited or node in recomputed_nodes: + return + visited.add(node) + for inp in node.all_input_nodes: + if must_recompute(inp): + gather_checkpointed_deps(inp, visited) + + # Insert backward nodes + for node in list(gm.graph.nodes)[bwd_start:]: + # Gather all checkpointed deps needed by this node + deps = set() + for inp in node.all_input_nodes: + if must_recompute(inp): + gather_checkpointed_deps(inp, deps) + + # Insert deps in forward order (guaranteed disjoint from already-inserted) + # This is not as inefficient as it looks, because we only add fresh dependencies + # when they are not yet processed as recomputed nodes. + for dep in sorted(deps, key=lambda n: order[n]): + assert dep not in recomputed_nodes, "We shouldn't have recomputed it before" + dup = new_graph.node_copy(dep, remat_input) + dup.name = dep.name + "_recomputed" + recomputed_nodes[dep] = dup + + env[node] = new_graph.node_copy(node, remat_input) + + new_gm = torch.fx.GraphModule(gm, new_graph) + + # DCE with custom is_impure_node (like default_partition) + # Treats certain collectives as pure while delegating to default impurity logic + new_gm.graph.eliminate_dead_code(is_impure_node=is_impure_node_for_dce) + + # raise_getitems pass for better memory (like default_partition) + new_gm = raise_getitems(new_gm) + + new_gm.recompile() + + return new_gm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/aot_autograd_result.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/aot_autograd_result.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbacfaf3080264bdb538ab96d22b71b6f64b12e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -0,0 +1,676 @@ +# mypy: allow-untyped-defs +""" +This module provides result classes for AOT Autograd compilation. + +Similar to how torch._inductor.output_code provides OutputCode classes for inductor +compilation results, this module provides AOTAutogradResult classes that represent +the compiled artifacts produced by AOT Autograd. + +These results are: +- Serializable: can be saved/loaded from disk without recompilation +- Addressable: can be stored in caches with keys for later retrieval +- Reusable: can be used for both caching and ahead-of-time compilation (precompile) + +The main result types are: +- GenericAOTAutogradResult: Abstract base for all AOT Autograd results +- AOTAutogradResult: Regular result that references FxGraphCache entries +- BundledAOTAutogradResult: Result that bundles the entire compiled code directly +""" + +from __future__ import annotations + +import json +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from copy import copy +from dataclasses import dataclass +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar + +import torch +from torch._dynamo.precompile_context import BackendCacheArtifact +from torch._inductor.codecache import FxGraphCache +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + OutputCode, +) +from torch._inductor.utils import should_use_remote_fx_graph_cache + +from .runtime_wrappers import ( + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + CachedAutogradLazyBackwardCompileInfo, + CompilerWrapper, + FunctionalizedRngRuntimeWrapper, + post_compile, + RuntimeWrapper, + SerializableCompiledFunction, + SubclassMeta, +) +from .schemas import AOTAutogradCacheInfo # noqa: F401 +from .utils import simple_wraps + + +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs + + from .schemas import AOTConfig, ViewAndMutationMeta + +log = logging.getLogger(__name__) + + +TOut = TypeVar("TOut", bound=OutputCode) + + +class InductorOutput(ABC, Generic[TOut]): + """ + Class representing a single inductor output + """ + + @abstractmethod + def pre_save(self) -> None: ... + + @abstractmethod + def load(self, example_inputs) -> TOut: ... + + @abstractmethod + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... + + +TOutputCode = TypeVar("TOutputCode", bound=OutputCode) + + +@dataclass +class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]): + """ + A generic wrapper for OutputCode objects that are bundled directly in the cache + (rather than looked up via FxGraphCache). + + This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + result: TOutputCode + + def pre_save(self) -> None: + disk_result = copy(self.result) + disk_result.prepare_for_serialization() + self.result = disk_result + return + + def load(self, example_inputs) -> TOutputCode: + self.example_inputs = example_inputs + return self.result + + def post_compile( + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: + constants = CompiledFxGraphConstants() + + # Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile + if isinstance(result, CompiledFxGraph): + graph, cache_info = FxGraphCache.cache_hit_post_compile( + result, {}, constants + ) + if graph is None: + raise RuntimeError("Failed to reload cache entry from disk") + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_bundled_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + result = graph # type: ignore[assignment] + + # Run normal post compile + result.post_compile(self.example_inputs, constants, fx_config) + return result + + +# Backwards compatibility alias +CompiledFxGraphLoadable: type[BundledOutputCodeLoadable[CompiledFxGraph]] = ( + BundledOutputCodeLoadable[CompiledFxGraph] +) + + +@dataclass +class FxGraphCacheLoadable(InductorOutput[CompiledFxGraph]): + fx_graph_cache_info: tuple[str, list[str]] + fx_graph_guard_expr: Optional[str] + + def pre_save(self): + return + + def _is_backward(self) -> bool: + return False + + def load(self, example_inputs) -> CompiledFxGraph: + from .autograd_cache import FXGraphCacheMiss + + # [Note: AOTAutogradCache and FXGraphCache Guard interactions] + # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. + # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. + # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # the same as the ones it passes to inductor, for both the forward and backward passes. + # (This does not mean that the tensor values passed in are the same: only that their symints are). + # That is, AOTAutograd and Inductor never create new guards based on symints with different sources + # than those passed to it by inductor. + # We pass the post compile function, which sets various fx_config boxed values, + # so we can call it only after we're sure both forward and backward have + # Clear CompiledTritonKernels before loading from FXGraphCache + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + remote_cache = None + constants = CompiledFxGraphConstants() + if should_use_remote_fx_graph_cache(): + remote_cache = FxGraphCache.get_remote_cache() + (cache_key, debug_lines) = self.fx_graph_cache_info + + def check_exact_guard_match(guard_expr, _hints): + """ + AOTAutogradCache tracks its own guards, so we just need to treat these guard expressions as a second + cache key of sorts: we just check for equality, i.e. the FXGraphCache entry with + the exact same guards as we originally saved into the cache. + """ + return guard_expr == self.fx_graph_guard_expr + + result, cache_info = FxGraphCache.load_with_key( + cache_key, + debug_lines, + example_inputs, + local=True, + remote_cache=remote_cache, + is_backward=self._is_backward(), + constants=constants, + evaluate_guards=check_exact_guard_match, + ) + if result is None: + log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_info) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_miss", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + + raise FXGraphCacheMiss + + # No need to log chromium event because AOTAutograd will log that immediately for us + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + self.example_inputs = example_inputs + self.constants = constants + return result + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + """ + Called after FXGraphCacheLoadable.load, mutates fx_config + """ + result.post_compile(self.example_inputs, self.constants, fx_config) + return result + + +@dataclass +class CompiledForward(FxGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + def _is_backward(self) -> bool: + return False + + +@dataclass +class GenericCompiledBackward(InductorOutput[TOut]): + # Used by AOTDispatchAutograd.post_compile + backward_state_indices: list[int] + num_symints_saved_for_bw_: int + + +@dataclass +class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + def _is_backward(self) -> bool: + return True + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + + +# Generic bundled forward/backward classes that work with any OutputCode type +@dataclass +class BundledCompiledForward( + BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode] +): + """ + Generic forward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + +@dataclass +class BundledCompiledBackward( + GenericCompiledBackward[TOutputCode], + BundledOutputCodeLoadable[TOutputCode], + Generic[TOutputCode], +): + """ + Generic backward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + def post_compile( + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + + +@dataclass +class SerializedGraphModule: + fn: Callable[[dict[Any, Any], str], torch.nn.Module] + args: tuple[Any, ...] + + def __init__(self, gm: torch.fx.GraphModule): + self.fn, self.args = gm.__reduce__() + + def deserialize(self) -> torch.fx.GraphModule: + gm = self.fn(*self.args) + assert isinstance(gm, torch.fx.GraphModule) + return gm + + +def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: + # NOTE: mutates the graph module + gm.meta = {} + for node in gm.graph.nodes: + node.meta = {} + return SerializedGraphModule(gm) + + +TForward = TypeVar("TForward", bound=InductorOutput) +TBackward = TypeVar("TBackward", bound=GenericCompiledBackward) + + +@dataclass +class GenericAOTAutogradResult(Generic[TForward, TBackward]): + """A single result from AOT Autograd compilation, genericized by Forward and Backward types. + + A TForward is always an InductorOutput of some sort, which represents the + forward graph of the compile. + A TBackward is an InductorOutput + metadata about the backward, useful for specific + backward-only wrappers. This type is encapsulated by GenericCompiledBackward. + + Each AOTAutogradResult is essentially parameterized by 1. the method of loading + from the cache (either Bundled or UnBundled), and 2. The type of the output. For now, + the only type of output we support is Python Wrapper output, i.e. OutputCode.CompiledFxGraph, + but the same technique works for C++ wrapper code; we'd just add an extra InductorOutput type. + """ + + # Forward and Backward info + compiled_fw: TForward + compiled_bw: Optional[TBackward] + + # Code of the joint graph using print_readable() + # Used for logging purposes + aot_joint_graph_str: Optional[str] + aot_forward_graph_str: Optional[str] + aot_backward_graph_str: Optional[str] + + # Runtime_metadata saved right before compilation + runtime_metadata: ViewAndMutationMeta + + # Wrappers that run after each aot_dispatch_* function + dispatch_wrappers: list[CompilerWrapper] + + # Used by AOTSubclassWrapper + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + # Used by RuntimeWrapper + indices_of_inps_to_detach: list[int] + + # Time taken to trace/compile the forward + # forward_time_taken includes AOTAutograd tracing time + inductor compilation time + # backward_time_taken is essentially just the time inductor took to compile + forward_time_taken_ns: int + backward_time_taken_ns: int + + # Used by standalone_compile + sanitized_aot_config: AOTConfig + + guards_expr: Optional[str] + + # Used by Compiled Autograd + serialized_bw_module: Optional[SerializedGraphModule] + + def pre_save(self): + """ + Perform any preparations to make the result ready for serialization. + """ + self.compiled_fw.pre_save() + if self.compiled_bw is not None: + self.compiled_bw.pre_save() + + # Turn result into the original callable + def wrap_post_compile( + self, + args: list[torch.Tensor], + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ) -> Callable: + """ + This function takes a result and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. + + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + from torch._dynamo.utils import CompileEventLogger, dynamo_timed + + # Log the output of AOTAutogradCache + if aot_config.enable_log: + # TODO: maybe also log to aot_graphs_log + # Unfortunately aot_graphs_log uses + # slightly different formatting though + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + + if self.aot_forward_graph_str is not None: + from torchgen.utils import dataclass_repr + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.runtime_metadata), + ) + if self.maybe_subclass_meta is not None: + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), + ) + + # It's called an inference graph if not running with autograd + name = ( + "aot_forward_graph" + if self.aot_backward_graph_str is not None + else "aot_inference_graph" + ) + torch._logging.trace_structured( + name, payload_fn=lambda: self.aot_forward_graph_str + ) + + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) + + # Wrap the forward function in post compile wrappers + compiled_fw_func = AOTDispatchSubclassWrapper( + trace_joint=needs_autograd, + fw_only=None, + maybe_subclass_meta=self.maybe_subclass_meta, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + req_subclass_dispatch = self.maybe_subclass_meta is not None + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + # In autograd case, functionalizedRngWrapper should not modify outs + return_new_outs = not needs_autograd + compiled_fw_func = FunctionalizedRngRuntimeWrapper( + return_new_outs=return_new_outs + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + compiled_fw_func._boxed_call = True + disable_amp = torch._C._is_any_autocast_enabled() + + if needs_autograd: + assert self.compiled_bw is not None + + cached_lazy_backward = None + if self.serialized_bw_module is not None: + cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( + self.serialized_bw_module.deserialize + ) + # This function is run on both cache miss and cache hit, either here + # or in aot_dispatch_autograd. On a cache hit, + # 1. the bw is already compiled + # 2. we don't need to save to the cache again + # so those corresponding arguments are set to None. + compiled_function = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + self.maybe_subclass_meta, + self.compiled_bw.num_symints_saved_for_bw_, + self.compiled_bw.backward_state_indices, + disable_amp, + self.indices_of_inps_to_detach, + cached_lazy_backward, + aot_config, + fw_metadata=self.runtime_metadata, + try_save_cache_entry=None, + ) + + else: + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + # Add serialization function back onto object + compiled_function, _ = post_compile( + self.dispatch_wrappers, + compiled_function, + aot_config, + runtime_metadata=self.runtime_metadata, + ) + + # Now that we're pretty sure it's a successful load, add guards + # to the existing shape environment from the cache + if self.guards_expr: + from .autograd_cache import AOTAutogradCache + + symints = AOTAutogradCache._filter_backed_symints(args) + check = bool(AOTAutogradCache.evaluate_guards(self.guards_expr, symints)) + assert check is True + + return compiled_function + + +class AOTAutogradResult(GenericAOTAutogradResult[CompiledForward, CompiledBackward]): + """ + Regular AOTAutogradResult: saves the forward/backward FxGraphCache keys + and looks them up in FxGraphCache on load + """ + + +class BundledAOTAutogradResult( + GenericAOTAutogradResult[ + BundledCompiledForward[TOutputCode], BundledCompiledBackward[TOutputCode] + ], + Generic[TOutputCode], +): + """ + Generic AOTAutogradResult where we bundle the entire OutputCode directly + (rather than looking it up via FxGraphCache). + + This works with any OutputCode type: + - CompiledFxGraph: Traditional inductor compilation + - RegionalOutputCode: Regional inductor compilation with GraphPickler serialization + - Any future OutputCode subclasses + + Type parameter: + TOutputCode: The OutputCode subclass (e.g., CompiledFxGraph, RegionalOutputCode) + + Usage with CompiledFxGraph: + entry = BundledAOTAutogradResult[CompiledFxGraph]( + compiled_fw=BundledCompiledForward(result=CompiledFxGraph(...)), + compiled_bw=BundledCompiledBackward( + result=CompiledFxGraph(...), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... + ) + + Usage with RegionalOutputCode: + entry = BundledAOTAutogradResult[RegionalOutputCode]( + compiled_fw=BundledCompiledForward(result=RegionalOutputCode(gm)), + compiled_bw=BundledCompiledBackward( + result=RegionalOutputCode(gm), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... + ) + """ + + +def deserialize_bundled_cache_entry(entry: BundledAOTAutogradResult) -> Callable: + from copy import deepcopy + + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.utils import BoxedBool + + # In the precompile use case, guards are already serialized + # by dynamo, so we don't need to add them to the environment + entry.guards_expr = None + # TODO: this isn't exactly right, because cudagraphs needs to be a shared config + # which is set by compile_fx. But in precompile, we never actually call compile_fx + # so we don't have a place to track cudagraphs here. + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + boxed_forward_device_index = BoxedDeviceIndex(None) + # We need to make a clean copy of the cache entry + # in case it needs to be serialized again + serializable_copy = deepcopy(entry) + + from torch._subclasses import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + context = torch._guards.TracingContext.try_get() + if context is None: + # Create a clean environment when running fx graph post compile + # if one is not available + context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) + with torch._guards.tracing(context): + compiled_fn = entry.wrap_post_compile( + [], + entry.sanitized_aot_config, + { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + }, + ) + # Ensure the deserialized cache entry is still serializable + + compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy) + + # TODO: this ignores flat_params, which can exist + # if inline_builtin_nn_modules=False + @simple_wraps(compiled_fn) + def forward(*runtime_args: tuple[Any]): + return compiled_fn(list(runtime_args)) + + assert hasattr(compiled_fn, "serialize") + forward.serialize = compiled_fn.serialize # type: ignore[attr-defined] + + return forward + + +@dataclass +class BundledAOTAutogradCacheArtifact(BackendCacheArtifact[Callable]): + def after_deserialization(self) -> Callable: + return deserialize_bundled_cache_entry(self.content) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7b4c8973c5df21187350e50ed5b40c18860cc4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py @@ -0,0 +1,1037 @@ +# mypy: allow-untyped-defs +""" +Utils for caching the outputs of AOTAutograd +""" + +from __future__ import annotations + +import base64 +import contextlib +import functools +import json +import logging +import os +import pickle +import random +import shutil +import time +import traceback +from copy import copy +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import override + +import torch +from torch._dynamo.precompile_context import PrecompileContext +from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions +from torch._dynamo.utils import chromium_event_log_active, CompileEventLogger, counters +from torch._functorch import config +from torch._inductor.codecache import ( + _ident, + add_ephemeral_timeout_increase_for_distributed, + BypassFxGraphCache, + create_cache, + extract_tensor_metadata_for_cache_key, + FxGraphCache, + FxGraphCachePickler, + FxGraphHashDetails, + GuardedCache, + sha256_hash, + write_atomic, +) +from torch._inductor.output_code import OutputCode +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import BoxedBool, should_use_remote_fx_graph_cache +from torch._logging import LazyString +from torch._utils_internal import log_cache_bypass +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.fx.experimental.symbolic_shapes import hint_int +from torch.utils._triton import has_triton_package + +from .aot_autograd_result import ( + AOTAutogradResult, + BundledAOTAutogradCacheArtifact, + BundledAOTAutogradResult, + BundledCompiledBackward, + BundledCompiledForward, + CompiledBackward, + CompiledForward, + GenericAOTAutogradResult, + SerializedGraphModule, +) +from .runtime_wrappers import ( + CompilerWrapper, + SerializableCompiledFunction, + SubclassMeta, +) +from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noqa: F401 + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch._inductor.compile_fx import _CompileFxKwargs + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.remote_cache import JsonDataTy, RemoteCache + from torch.fx.node import Node + + +log = logging.getLogger(__name__) + + +class BypassAOTAutogradCache(Exception): + pass + + +# Used to signify when FXGraphCache missed when AOTAutogradCache uses it +class FXGraphCacheMiss(BypassAOTAutogradCache): + pass + + +def should_use_remote_autograd_cache(): + if torch.compiler.config.force_disable_caches: + return False + if config.enable_remote_autograd_cache is not None: + return config.enable_remote_autograd_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk_name = "pytorch/remote_cache:aot_autograd_cache_version" + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + + +def should_use_local_autograd_cache(): + if torch.compiler.config.force_disable_caches: + return False + return config.enable_autograd_cache + + +def should_bundle_autograd_cache(): + return config.bundled_autograd_cache or torch._dynamo.config.caching_precompile + + +def check_node_safe(node: Node): + """ + Checks that the node only uses supported operators. We are starting with very + conservative cacheability constraints, and incrementally adding more support as we expand. + + [Note: AOTAutograd Cacheability checks] + - Our cache key is computed from the FX graph produced by Dynamo and the input example values + - A node is "safe" if the same cache key results in a compiled artifact that has the same behavior + (i.e, the set of inputs that go into our cache key is sufficient to distinguish its behavior) + + To accomplish this safety check, we consider the following functions to be safe: + - Public functions under modules torch, torch.functional, and torch.nn.functional: these are + allowed in the graph by dynamo, so we can assume they are safe to cache. + - method calls on base tensor types + - Any call_module that dynamo deemed safe to allow AOTAutograd to trace + - Non callable nodes, such as placeholder, output, get_attr + + The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully cover/specify this behavior. + """ + SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional") + SAFE_TORCH_FUNCTIONS = ( + "torch.Size", + "torch.Tensor", + "torch.sym_int", + "torch._sym_sqrt", + "torch.sym_float", + "torch.sym_sum", + ) + SAFE_NON_TORCH_FUNCTIONS = ( + "einops.einops.rearrange", + "einops.einops.repeat", + ) + + def is_public_torch_api(target): + # Don't blindly allow private functions in the torch namespace + is_private = target.__name__.startswith("_") + + return ( + getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private + ) + + def is_safe_torch_function(target): + """Allowlisted torch functions""" + function_name = f"{target.__module__}.{target.__name__}" + # Allow torch.autograd.function.FunctionCtx if custom autograd functions are allowed + if function_name == "torch.autograd.function.FunctionCtx": + return ( + torch._functorch.config.autograd_cache_allow_custom_autograd_functions + ) + + # Functions in torch_non_c_binding_in_graph_functions + # are guaranteed to be cache safe. + # See NOTE: [Cacheability of in-graph torch functions] + return ( + function_name in torch_non_c_binding_in_graph_functions + or function_name in SAFE_TORCH_FUNCTIONS + or function_name in torch._inductor.config.unsafe_marked_cacheable_functions + ) + + def is_cacheable_function(target): + if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return True + if is_public_torch_api(target): + return True + # Technically, FXGraphCache._check_for_hop already checks this, + # but better to error earlier anyway + if isinstance(target, torch._ops.HigherOrderOperator): + return target.cacheable() + is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method" + if is_builtin_fun_or_type: + return True + if is_safe_torch_function(target): + return True + function_name = f"{target.__module__}.{target.__name__}" + if function_name in SAFE_NON_TORCH_FUNCTIONS: + return True + return False + + def is_tensor(target): + # Tensors always have example values in meta field + return "example_value" in target.meta + + # I'd love to use a match statement here, but it wasn't introduced until py3.10 + if node.op == "call_function": + if node.meta and node.meta.get("is_wrapped", False): + # This is fx.wrap function + # By default we BypassAOTAutogradCache for unknown functions, + # But if user explicitly specified cache hash - allow to cache it. + if node.meta.get("user_cache_hash", None): + return + + if not is_cacheable_function(node.target): + module = getattr(node.target, "__module__", None) + name = getattr(node.target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_function target {node.target}. \n Function module: {module}, \nFunction name: {name}" + ) + elif node.op == "call_method": + method_name = node.target + method_target = node.args[0] + # Only support method calls on base tensors + if not is_tensor(method_target): + module = getattr(method_target, "__module__", None) + name = getattr(method_target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_method target {method_target}. \nMethod module: {module}, \nMethod name: {name}" + ) + if ( + type(method_name) is not str + and type(method_name).__name__ != "method_descriptor" + ): + raise BypassAOTAutogradCache( + f"Unsupported call_method method {node.target}: {method_name}" + ) + # Cache safe + elif node.op in ("placeholder", "get_attr", "call_module", "output"): + # Assumption today for call_module being a safe op: + # (1) today the only call_module ops that can show up in a graph come from "built-in-nn-modules" + # that dynamo assumes are safe to trace. If dynamo assumes they are safely to blindly trace, then + # they should be safe to cache as well. + # (2) in the steady-state (some time in H2?) we shouldn't see these anymore, once inline builtin nn modules by default + # (3) We do not allow user made nn modules in the graph today, only function calls. + pass + else: + raise BypassAOTAutogradCache(f"Unsupported node op {node.op}") + + +def check_cacheable(gm: torch.fx.GraphModule): + """ + Checks that the graph module only uses supported operators + """ + nodes = gm.graph.nodes + if torch._inductor.config.freezing: + raise BypassAOTAutogradCache("Cannot cache a graph with freezing enabled") + + if not ( + torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache() + ): + raise BypassAOTAutogradCache("FX graph cache is not enabled") + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context and tracing_context.fakify_first_call: + raise BypassAOTAutogradCache( + "Won't cache a graph with fakify_first_call enabled" + ) + for node in nodes: + check_node_safe(node) + + # Saved tensors hooks are globally set subgraphs, + # that are not used explicitly in the main graph. + # They are inlined in aot_autograd graphs. + # Subgraphs are only used for caching logic. + if hasattr(gm, "saved_tensors_hooks_pack_0"): + check_cacheable(gm.saved_tensors_hooks_pack_0) # type: ignore[arg-type] + # We have guarantee of unpack sugraph existence if pack subgraph exists + check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] + + +class AOTAutogradCacheDetails(FxGraphHashDetails): + """ + Object to capture all the details for a dynamo graph module relevant to computing + a safe and stable cache key for AOTAutograd. + """ + + def get_triton_source_codes_from_gm( + self, + gm: torch.fx.GraphModule, + ): + assert has_triton_package(), "Triton is not available" + + triton_kernels = [] + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if isinstance(node.target, torch._ops.OpOverloadPacket): + attrs = node.target._dir + for attr in attrs: + if custom_op := getattr(node.target, attr, None): + kernels = torch._library.triton.get_triton_kernels_for_op( + custom_op._name + ) + triton_kernels.extend(kernels) + elif isinstance(node.target, torch._ops.OpOverload): + kernels = torch._library.triton.get_triton_kernels_for_op( + node.target._name + ) + triton_kernels.extend(kernels) + + triton_kernel_source_codes = [] + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + for kernel in triton_kernels: + from triton.runtime.autotuner import Autotuner + + if isinstance(kernel, Autotuner): + # Grab the Inner JITFunction + kernel = kernel.fn + source_codes = user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + triton_kernel_source_codes.append(source_codes) + + return triton_kernel_source_codes + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs, + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ): + # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info + self.aot_config = aot_config + self.grad_enabled = torch.is_grad_enabled() + self.disable_amp = torch._C._is_any_autocast_enabled() + self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled() + self.autograd_config = config.save_config() + self.saved_tensors_hooks_fx_wrap_cache_hashes: tuple[list[str], list[str]] = ( + [], + [], + ) + if has_triton_package(): + self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm) + + if hasattr(gm, "saved_tensors_hooks_pack_0"): + + def _add_wrapped_user_cache_hashes(_gm, _l): + for node in _gm.graph.nodes: + if node.meta and node.meta.get("is_wrapped", False): + _l.append(node.meta["user_cache_hash"]) + + _add_wrapped_user_cache_hashes( + gm.saved_tensors_hooks_pack_0, + self.saved_tensors_hooks_fx_wrap_cache_hashes[0], + ) + _add_wrapped_user_cache_hashes( + gm.saved_tensors_hooks_unpack_0, + self.saved_tensors_hooks_fx_wrap_cache_hashes[1], + ) + + try: + # FXGraphCache has constraints on what can be pickled in its inductor + # config. Check that the gm is cacheable by inductor first, + # and if it raises an exception, also bypass on our end. + FxGraphCache._check_can_cache(gm) + super().__init__(gm, example_inputs, fx_config, []) + except BypassFxGraphCache as e: + # Sometimes inductor configs are unpickleable and can fail + raise BypassAOTAutogradCache(str(e)) from e + + +class AOTAutogradCachePickler(FxGraphCachePickler): + def __init__(self, gm: torch.fx.GraphModule): + super().__init__(gm) + # pyrefly: ignore [bad-override] + self.dispatch_table: dict + self.dispatch_table.update( + { + AOTConfig: functools.partial(self._reduce_aot_config), + torch.Tensor: functools.partial(self._reduce_tensor), + } + ) + + def _reduce_aot_config(self, aot_config: AOTConfig): + """ + Reduce the config to a stable key for caching. + """ + return ( + _ident, + ( + aot_config.num_params_buffers, + aot_config.keep_inference_input_mutations, + aot_config.is_export, + aot_config.no_tangents, + aot_config.dynamic_shapes, + aot_config.aot_autograd_arg_pos_to_source, + aot_config.enable_log, + aot_config.pre_dispatch, + ), + ) + + def _reduce_tensor(self, tensor): + """ + Reduce the tensor to a stable key for caching. + """ + metadata = extract_tensor_metadata_for_cache_key(tensor) + return (_ident, (metadata,)) + + +@contextlib.contextmanager +def normalize_placeholder_names(gm: torch.fx.GraphModule): + """ + Context manager that normalizes the placeholder names in the graph module. + This is used while generating a cache key for AOTAutogradCache, so that two graphs + that are isomorphic when normalizing names can hit the same cache entry. + This is safe because nothing underneath AOTAutograd uses the node names on the + original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are + in terms of original sources rather than placeholder names. + """ + # Standalone inductor: we're bypassing AOTAutogradCache anyway, so return the graph + # as-is + if not config.autograd_cache_normalize_inputs or not hasattr(gm, "graph"): + yield + return + + # Track all the old state of placeholders + old_placeholder_names = [] + old_used_names = copy(gm.graph._graph_namespace._used_names) + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + # _rename renames the node in the body of the function, + # but it doesn't change the raw name from node.target + # So we also set the raw_name of node.target to a new placeholder name + new_placeholder_name = f"p_{i}" + old_placeholder_names.append((n.name, n.target)) + n.target = new_placeholder_name + n._rename(new_placeholder_name) + i += 1 + gm.recompile() + try: + yield + finally: + # Used_names contains all our old placeholder names, + # so we clear it temporarily when we put them back + gm.graph._graph_namespace._used_names = set() + # Restore the placeholder names + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + (name, target) = old_placeholder_names[i] + n.target = target + n._rename(name) + i += 1 + assert i == len(old_placeholder_names) + # Now restore the old namespace's used names + gm.graph._graph_namespace._used_names = old_used_names + gm.recompile() + + +def autograd_cache_key( + gm: torch.fx.GraphModule, + example_inputs, + config: AOTConfig, + fx_config: _CompileFxKwargs, + # TODO: add args and parameters +) -> tuple[str, list[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + + try: + check_cacheable(gm) + if has_triton_package(): + # Due to https://github.com/triton-lang/triton/issues/3729, + # if triton is < 3.2.0, AOTAutogradCache may cause us to + # attempt to load a cache entry without initializing + # the CUDA context on the autograd thread. + + # Without caching, we naturally do this initialization when + # tracing through the graph with the autograd engine. + import triton + + if triton.__version__ < "3.2.0": + raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") + details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) + pickler = AOTAutogradCachePickler(gm) + # The prefix distinguishes among the other kinds of objects we cache + key = "a" + pickler.get_hash(details) + debug_lines = pickler.debug_lines(details) + log.debug( + "Autograd graph cache hash details for key %s:\n%s", + key, + LazyString(lambda: "\n".join(debug_lines)), + ) + return key, debug_lines + except Exception: + # If enable_aot_compile is set, we're in AOT precompile mode where we always + # want to use fallback nonce keys. Unlike caching, it's fine if we can't generate + # a proper key because we are guaranteed in an AOT precompile world users are in + # complete control of distributing and loading artifacts. + if torch._dynamo.config.enable_aot_compile: + log.info( + "Failed to generate AOTAutograd cache key; falling back to nonce due to enable_aot_compile", + exc_info=True, + ) + return str(random.random()), [] + else: + raise + + +@contextlib.contextmanager +def sanitize_gm_for_cache(gm: torch.fx.GraphModule): + """ + Clears a few fields in a dynamo supplied Graph Module that are not stable between graph inputs, but don't + affect inductor or aotdispatch correctness. + + These fields **can** be used by code calling into aotdispatch (namely, dynamo), so we can't null them out completely. + + To ensure that these fields are not accessed by inductor or aotdispatch, we clear them during AOTAutogradCache.load, + and then put them back before returning. This way, we generate a cache key based off of a canonical graph + without these fields, and also guarantee they aren't used to affect the cache's output. + """ + # Mapping from each field to a default value + IGNORED_FIELDS: dict[str, Any] = { + "meta": {}, # metadata used by export + "compile_subgraph_reason": None, # Used by dynamo only for logging, no change in inductor/autograd behavior + "_param_name_to_source": None, # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id": None, + } + saved_fields = {} + for field, default_value in IGNORED_FIELDS.items(): + saved_fields[field] = getattr(gm, field, None) + # Clear the field + setattr(gm, field, default_value) + try: + with normalize_placeholder_names(gm): + yield + finally: + for field, value in saved_fields.items(): + setattr(gm, field, value) + + +@CacheArtifactFactory.register +class AOTAutogradCacheArtifact(CacheArtifact): + @override + def populate_cache(self): + AOTAutogradCache._write_to_local_cache(self.key, self.content) + + @override + @staticmethod + def type(): + return "aot_autograd" + + +class AOTAutogradCache(GuardedCache[GenericAOTAutogradResult]): + """ + Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas + AOTAutogradResult handles the wrapping/unwrapping logic. + + Cache Inputs (AOTAutogradCacheDetails) + - AOTAutogradCache takes in the following inputs, which are analogous to inputs given + to AOTAutograd by dynamo: + - A fx graph module generated by dynamo + - A list of args, which consists of: + - Symint inputs to the graph, generated by dynamo + - The **real tensor** inputs, which inductor uses for cudagraphs + - Notably, the real tensor inputs don't have symints in their metadata. + AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution. + - A set of global configurations that affect AOTAutograd or Inductor behavior. + + It then generates a cache key given these values. Notably, this means AOTAutogradCache currently + specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on. + In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates + based on the real tensor inputs, which can contain symints. + + # Cache Outputs (AOTAutogradResult) + - AOTAutogradCache caches the following values: + - The compiled forward and backward functions from inductor, via keys to the FXGraphCache + - Metadata to reconstruct the AOTModule from the compiled inductor artifacts + - See AOTAutogradResult for more info + + [Note: Caching guards generated by AOTAutograd and Inductor] + AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each + compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions + from FXGraphCache, giving it new symint arguments from the input args. + FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards. + **No new guards are generated into the shape env after inductor finishes compiling**, so the guards + saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches. + """ + + @staticmethod + def clear(): + """Clear the cache""" + try: + shutil.rmtree(AOTAutogradCache._get_tmp_dir()) + except FileNotFoundError: + pass + + @staticmethod + def try_load( + mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], + args, + aot_config: AOTConfig, + cudagraphs: BoxedBool, + boxed_forward_device_index: Optional[BoxedDeviceIndex], + local: bool, + remote: bool, + ) -> Optional[Callable]: + """ + Load a result from the cache, and reconstruct a runtime wrapper around the object + """ + gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod + with sanitize_gm_for_cache(gm): + compiled_fn = None + cache_info: dict[str, Any] = {} + cache_key = None + debug_lines: list[str] = [] + cache_event_time = time.time_ns() + cache_state = None + fx_config: _CompileFxKwargs = { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + } + try: + cache_key, debug_lines = autograd_cache_key( + gm, args, aot_config, fx_config + ) + result: Optional[tuple[GenericAOTAutogradResult, bytes]] = ( + AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config + ) + ) + if result is not None: + (entry, pickled_content) = result + compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) + # Make the compiled_fn serializable, where the serialize function just + # makes a copy of the original entry before post compile via the pickled content + compiled_fn = SerializableCompiledFunction( + compiled_fn, lambda: pickle.loads(pickled_content) + ) + log.info("AOTAutograd cache hit for key %s", cache_key) + + counters["aot_autograd"]["autograd_cache_hit"] += 1 + cache_state = "hit" + cache_event_time = time.time_ns() + forward_time_saved = entry.forward_time_taken_ns // 1e6 + backward_time_saved = entry.backward_time_taken_ns // 1e6 + cache_info.update( + { + "forward_time_saved_ms": forward_time_saved, + "backward_time_saved_ms": backward_time_saved, + "time_saved_ms": forward_time_saved + backward_time_saved, + } + ) + time_saved_ns = ( + entry.forward_time_taken_ns + entry.backward_time_taken_ns + ) + # TODO: should we use the same field for remote cache time saved for both + # FXGraphCache and AOTAutogradCache? + # get_metrics_context().increment(...) + if ( + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + + if compiled_fn is None: + log.info("AOTAutograd cache miss for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + cache_event_time = time.time_ns() + # Count missing the FXGraphCache as a miss not a bypass + except FXGraphCacheMiss as e: + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + if ( + config.strict_autograd_cache + or torch._dynamo.config.strict_precompile + ): + raise e + # Most often this is BypassAOTAutogradCache, but + # if there's ever different reason we can't cache, + # we still never want to hard throw an exception, since + # we can always fallback to a cache bypass. + # As an example, if the user calls autograd via + # standalone inductor, we will sometimes get a GraphModule + # that doesn't actually have a `.graph` on it. Instead + # of checking every single case, we safely catch the exception + # in those cases. + except Exception as e: + cache_key = None + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 + cache_state = "bypass" + cache_event_time = time.time_ns() + cache_info["cache_bypass_reason"] = str(e) + cache_info["cache_bypass_exception_type"] = type(e).__name__ + cache_info["cache_bypass_traceback"] = traceback.format_exc().split( + "\n" + ) + # TODO: this gets logged implicitly by cache_bypass_reason, + # and here we explicitly log it into tlparse. + # We may want to log this as an extra column in Scuba, though. + cache_info["cache_bypass_hard_exception"] = not isinstance( + e, BypassAOTAutogradCache + ) + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + if ( + config.strict_autograd_cache + or torch._dynamo.config.strict_precompile + ): + raise e + if compiled_fn is None: + # Set the cache key so we can save a cache result later + symints = AOTAutogradCache._filter_backed_symints(args) + if cache_key is not None: + aot_config.cache_info = AOTAutogradCacheInfo( + cache_key, + time.time_ns(), + forward_symints=symints, + ) + + cache_info.update( + { + "key": cache_key, + "cache_state": cache_state, + "components": debug_lines, + } + ) + if chromium_event_log_active(): + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"aotautograd_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + + return compiled_fn + + @classmethod + def generate_guards_expression( + cls: type[AOTAutogradCache], cache_info: AOTAutogradCacheInfo + ) -> Optional[str]: + shape_env = cls._get_shape_env() + assert shape_env is not None + symints = cache_info.forward_symints + guards = shape_env.get_pruned_guards(symints) + return shape_env.produce_guards_expression(placeholders=symints, guards=guards) + + @classmethod + def _get_tmp_dir(cls: type[AOTAutogradCache]) -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "aotautograd") + + @classmethod + def _get_tmp_dir_for_key(cls: type[AOTAutogradCache], key) -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cls._get_tmp_dir(), key) + + @staticmethod + def evaluate_guards(guard_expr: str, hints: Union[list[int], list[torch.SymInt]]): + if torch._inductor.config.unsafe_skip_cache_dynamic_shape_guards: + return True + shape_env = AOTAutogradCache._get_shape_env() + assert shape_env is not None + result = shape_env.evaluate_guards_expression(guard_expr, hints) + return result + + @staticmethod + def _lookup( + key: str, + local: bool, + remote: bool, + args: list[Any], + cache_info: dict[str, Any], + aot_config: Optional[AOTConfig], + ) -> Optional[tuple[GenericAOTAutogradResult, bytes]]: + """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" + remote_cache: Optional[RemoteCache[JsonDataTy]] = None + if remote: + remote_cache = AOTAutogradCache.get_remote_cache() + + symints = AOTAutogradCache._filter_backed_symints(args) + hints = [hint_int(s) for s in symints] + entry = None + pickled_content = None + try: + ( + entry, + pickled_content, + guard_info, + ) = AOTAutogradCache.find_guarded_entry( + key, local, remote_cache, AOTAutogradCache.evaluate_guards, hints + ) + + if entry is None and guard_info["cache_status_detailed"] == "guard_miss": + counters["aot_autograd"]["autograd_cache_guard_miss"] += 1 + cache_info.update(guard_info) + if pickled_content is not None: + CacheArtifactManager.record_artifact( + AOTAutogradCacheArtifact.type(), key, pickled_content + ) + if ( + should_bundle_autograd_cache() + and aot_config is not None + and aot_config.precompile_backend_id is not None + ): + # NB: We don't want to use the cached aot_config.precompile_backend_id + # 1. because we set it to None on save 2. even if we didn't, this new run + # that cache hit has a *new* backend id associated with it. + PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact( + aot_config.precompile_backend_id, entry + ), + ) + except Exception as e: + log.info("AOTAutograd cache unable to load compiled graph: %s", e) # noqa: G200 + if config.strict_autograd_cache: + raise e + if entry is not None: + assert pickled_content is not None + return (entry, pickled_content) + else: + return None + + @staticmethod + def _write_to_local_cache(key: str, content: bytes): + """Write an entry to the local cache.""" + subdir = AOTAutogradCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized entry to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + log.info("Writing AOTAutograd cache entry to %s", path) + write_atomic(path, content) + + @staticmethod + def save(key: str, entry: GenericAOTAutogradResult, remote: bool): + """Save a single entry into the cache.""" + try: + entry.pre_save() + content = pickle.dumps(entry) + CacheArtifactManager.record_artifact( + AOTAutogradCacheArtifact.type(), key, content + ) + if ( + should_bundle_autograd_cache() + and entry.sanitized_aot_config.precompile_backend_id is not None + ): + precompile_key = entry.sanitized_aot_config.precompile_backend_id + artifact = BundledAOTAutogradCacheArtifact(precompile_key, entry) + # Now that we're saving it, the precompile_backend_id field is no longer + # useful, remove it from the entry. + entry.sanitized_aot_config.precompile_backend_id = None + PrecompileContext.record_artifact(artifact) + AOTAutogradCache._write_to_local_cache(key, content) + counters["aot_autograd"]["autograd_cache_saved"] += 1 + except BypassAOTAutogradCache as e: + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + return None + except Exception as e: + log.info("AOTAutograd cache unable to serialize compiled graph: %s", e) # noqa: G200 + if remote: + log_cache_bypass( + "bypass_aot_autograd", "Unable to serialize: " + str(e) + ) + if config.strict_autograd_cache: + raise e + return None + + if remote: + remote_cache: Optional[RemoteCache[JsonDataTy]] = ( + AOTAutogradCache.get_remote_cache() + ) + if remote_cache is not None: + time_taken_ms = int( + (entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6 + ) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + + @staticmethod + @functools.cache + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. + """ + cache_id = "autograd-experimental" + return create_cache( + cache_id, + config.is_fbcode(), + "FbRemoteAOTAutogradCache", + "RemoteAOTAutogradCache", + ) + + @staticmethod + def make_entry( + compiled_fw_func: OutputCode, + compiled_bw_func: Optional[OutputCode], + aot_joint_graph_str: Optional[str], + aot_forward_graph_str: Optional[str], + aot_backward_graph_str: Optional[str], + runtime_metadata: ViewAndMutationMeta, + dispatch_wrappers: list[CompilerWrapper], + maybe_subclass_meta: Optional[SubclassMeta], + num_fw_outs_saved_for_bw: Optional[int], + indices_of_inps_to_detach: list[int], + forward_time_taken_ns: int, + backward_time_taken_ns: int, + sanitized_aot_config: AOTConfig, + guards_expr: Optional[str], + backward_state_indices: Optional[list[int]], + num_symints_saved_for_bw: Optional[int], + serialized_bw_module: Optional[SerializedGraphModule], + ) -> GenericAOTAutogradResult: + if should_bundle_autograd_cache(): + # Helper function to unwrap all the wrappers we added during aotdispatch + # They get reapplied on cache load + def unwrap_output_code(obj): + while hasattr(obj, "__wrapped__"): + obj = obj.__wrapped__ + assert isinstance(obj, OutputCode) + return obj + + compiled_fw_graph = unwrap_output_code(compiled_fw_func) + bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph) + bundled_compiled_backward = None + if compiled_bw_func is not None: + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_bw_graph = unwrap_output_code(compiled_bw_func) + bundled_compiled_backward = BundledCompiledBackward( + compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw + ) + + return BundledAOTAutogradResult( + compiled_fw=bundled_compiled_forward, + compiled_bw=bundled_compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, + ) + + else: + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + fw_debug_lines = getattr( + compiled_fw_func, "_fx_graph_cache_debug_lines", [] + ) + + assert fw_key is not None + compiled_forward = CompiledForward( + fx_graph_cache_info=(fw_key, fw_debug_lines), + fx_graph_guard_expr=getattr(compiled_fw_func, "guards_expr", None), + ) + compiled_backward = None + if compiled_bw_func is not None: + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + bw_debug_lines = getattr( + compiled_bw_func, "_fx_graph_cache_debug_lines", [] + ) + assert bw_key is not None + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_backward = CompiledBackward( + fx_graph_cache_info=(bw_key, bw_debug_lines), + fx_graph_guard_expr=getattr(compiled_bw_func, "guards_expr", None), + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw_=num_symints_saved_for_bw, + ) + + return AOTAutogradResult( + compiled_fw=compiled_forward, + compiled_bw=compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..041d321fec56da208dff93ccac9cd85eabd3b4c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py @@ -0,0 +1,336 @@ +# mypy: ignore-errors + +import warnings +from collections.abc import KeysView +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._guards import detect_fake_mode +from torch._library.opaque_object import is_opaque_type +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .descriptors import BufferAOTInput, DifferentiableAOTInput, ParamAOTInput +from .schemas import AOTConfig, FakifiedFlatArgs + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +def process_inputs( + flat_args: list[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], + ignore_shape_env: bool = False, +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None and not ignore_shape_env: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. + if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source, positive=x >= 0), + hint=x, + source=source, + ) + if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def _try_get_metadata_from_dynamo( + mod: torch.nn.Module, + param_keys: KeysView[str], + full_args_num: int, + full_args_descs: list[DifferentiableAOTInput], +) -> tuple[Optional[list[torch._guards.Source]], list[int]]: + """ + Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. + We first verify that `mod` does come from Dynamo, then we handle cases where + metadata might be missing. + + Returns: + aot_autograd_arg_pos_to_source: used to dedup params and their guards + static_input_indices: used to identify static inputs for cudagraphs + """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): + # graph was not captured by dynamo + return None, [] + + if not hasattr(mod, "_param_name_to_source"): + # is from export + static_input_indices = [ + i + for i, node in enumerate(full_args_descs) + if isinstance(node, (ParamAOTInput, BufferAOTInput)) + ] + return None, static_input_indices + + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + # Additionally, we mark static indices for cudagraphs. + param_name_to_source = mod._param_name_to_source + seen_sources = set() + + aot_autograd_arg_pos_to_source = [] + static_input_indices = [] + # Collect the new inputs lifted by aotdispatch + for i, name in enumerate(param_keys): + assert name in param_name_to_source, f"{name} not found." + source = param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + static_input_indices.append(i) + + # Collect the dynamo graph inputs + # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID + # matched tensors back into the Fx graph, this might not be necessary. + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + # `source`` specifies the source from user code. ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name if source else str(source) + + # input[i] in dynamo is now: + # input[i + len(extra_params)] in AOT, + # where extra_params are the params/buffers that dynamo baked into the + # OutputGraph + actual_pos = pos + len(param_keys) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", actual_pos, source_name + ) + static_input_indices.append(actual_pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", actual_pos, source_name + ) + + assert full_args_num == len(aot_autograd_arg_pos_to_source) + return aot_autograd_arg_pos_to_source, static_input_indices + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. + + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + def _get_all_module_attributes(mod): + # return attributes from all modules and submodules + result = {} + for name, submodule in mod.named_modules(): + result[name] = _get_attributes(submodule) + return result + + def _restore_all_module_attributes(mod, snapshot): + # restore attributes to all modules and submodules + for name, submodule in mod.named_modules(): + if name in snapshot: + submodule.__dict__.update(snapshot[name]) + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_all_module_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + + def _collect_assigned_tensor_attributes(snapshot, new_attrs): + assigned_tensor_attributes = [] + + def _compare_values(path, old_val, new_val): + """Recursively compare values, handling containers.""" + # Same object, no change + if old_val is new_val: + return + + if old_val is None or new_val is None: + if isinstance(new_val, torch.Tensor): + assigned_tensor_attributes.append(path) + return + + # Check if it's a tensor that was reassigned + if isinstance(new_val, torch.Tensor): + assigned_tensor_attributes.append(path) + return + + # Handle dict containers + if isinstance(old_val, dict) and isinstance(new_val, dict): + all_keys = set(old_val.keys()) | set(new_val.keys()) + for key in all_keys: + old_item = old_val.get(key) + new_item = new_val.get(key) + _compare_values(f"{path}[{key!r}]", old_item, new_item) + return + + # Handle list/tuple containers + if isinstance(old_val, (list, tuple)) and isinstance( + new_val, (list, tuple) + ): + # Different lengths = mutation happened + max_len = max(len(old_val), len(new_val)) + for i in range(max_len): + old_item = old_val[i] if i < len(old_val) else None + new_item = new_val[i] if i < len(new_val) else None + _compare_values(f"{path}[{i}]", old_item, new_item) + return + + # For other types, just check if they're different objects + # (we don't care about non-tensor mutations) + + for module_name in snapshot.keys() | new_attrs.keys(): + old_module_attrs = snapshot.get(module_name, {}) + new_module_attrs = new_attrs.get(module_name, {}) + + for attr_name in old_module_attrs.keys() | new_module_attrs.keys(): + module_prefix = f"self.{module_name}." if module_name else "self." + full_path = f"{module_prefix}{attr_name}" + + old_val = old_module_attrs.get(attr_name) + new_val = new_module_attrs.get(attr_name) + _compare_values(full_path, old_val, new_val) + + return assigned_tensor_attributes + + new_attrs = _get_all_module_attributes(mod) + assigned_tensor_attributes = _collect_assigned_tensor_attributes( + snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + _restore_all_module_attributes(mod, snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + warnings.warn( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).", + stacklevel=2, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..491cf3e1fe8cfad65cea4394b0eb2bcbe9832910 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py @@ -0,0 +1,317 @@ +""" +This module contains utility functions for working with joint FX graphs with descriptors +that are produced by AOTAutograd. They will NOT work on generic FX graphs. See also +:func:`torch._functorch.aot_autograd.aot_export_joint_with_descriptors`. We also +recommend reading :mod:torch._functorch._aot_autograd.descriptors`. +""" + +from typing import NoReturn, Optional, Union + +import torch.fx as fx + +from .descriptors import ( + AOTInput, + AOTOutput, + BufferAOTInput, + DifferentiableAOTInput, + DifferentiableAOTOutput, + GradAOTOutput, + ParamAOTInput, + PlainAOTInput, + PlainAOTOutput, + SubclassGetAttrAOTInput, + SubclassGetAttrAOTOutput, + TangentAOTInput, +) + + +def _raise_autograd_subclass_not_implemented( + n: fx.Node, desc: Union[AOTInput, AOTOutput] +) -> NoReturn: + raise RuntimeError( + "Subclasses are currently not supported by this function, but a desugared subclass input " + f"was found at {n} ({desc}). The problem is " + "that there may not necessarily be a 1-1 correspondence between primals/tangents/outputs/grads " + "when subclasses are involved: for example, the primal might be a plain tensor " + "but the tangent a tensor subclass that desugared into multiple plain tensors. " + "It is not clear what exactly you would like this function to do in this case " + "(Collect all nodes for the subclass together? Match up the inner nodes if " + "subclasses match exactly?) If you have a concrete use case, please file an " + "issue so we can understand it and design an API that works for your case." + ) + + +def get_all_input_and_grad_nodes( + g: fx.Graph, +) -> dict[DifferentiableAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """ + Given a joint graph with descriptors (meta['desc'] on placeholders and + output), returns the node for every input and its corresponding grad + output node if it exists. These tuples are in a dict that is indexed by + the AOTInput descriptor that describes the input. + + NB: *all* forward tensor inputs are returned, including non-differentiable + inputs (which simply have a None grad), so it is safe to use this function + to perform operations on all inputs. (Non-tensor inputs like symbolic + integers, tokens or RNG state are NOT traversed by this function.) + + Args: + g: The FX joint graph with descriptors + + Returns: + A dictionary mapping each DifferentiableAOTInput descriptor to a tuple + containing: + - The input node itself + - The grad (output) node if it exists, None otherwise + + Raises: + RuntimeError: If the joint graph has subclass tensor inputs/outputs; this + is not supported by API as there is not necessarily a 1-1 correspondence + between inputs and grads when subclasses are involved. + """ + input_index: dict[DifferentiableAOTInput, tuple[fx.Node, Optional[fx.Node]]] = {} + for n in g.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + # Skip inputs that cannot possibly be differentiable + if not isinstance(desc, DifferentiableAOTInput): + continue + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_autograd_subclass_not_implemented(n, desc) + # pyrefly: ignore [unsupported-operation] + input_index[desc] = (n, None) + elif n.op == "output": + assert "desc" in n.meta, (n, n.meta) + desc = n.meta["desc"] + for sub_n, sub_desc in zip(n.args[0], desc): + if isinstance(sub_desc, SubclassGetAttrAOTOutput): + _raise_autograd_subclass_not_implemented(sub_n, sub_desc) + if isinstance(sub_desc, GradAOTOutput): + inp, grad = input_index[sub_desc.grad_of] + assert grad is None, (sub_n, sub_desc, input_index) + input_index[sub_desc.grad_of] = (inp, sub_n) + return input_index + + +def get_all_output_and_tangent_nodes( + g: fx.Graph, +) -> dict[DifferentiableAOTOutput, tuple[fx.Node, Optional[fx.Node]]]: + """Get all output nodes and their corresponding tangent nodes from a joint graph. + + Similar to get_all_input_and_grad_nodes, but returns output nodes paired with + their tangent nodes (if they exist). This function traverses the graph to find + all differentiable outputs and matches them with their corresponding tangent + inputs used in forward-mode autodiff. + + NB: *all* forward tensor output sare turned, including non-differentiable outputs, + so you can use this function to perform operations on all outputs. + + Args: + g: The FX joint graph with descriptors + + Returns: + A dictionary mapping each DifferentiableAOTOutput descriptor to a tuple + containing: + - The output node itself + - The tangent (input) node if it exists, None otherwise + + Raises: + RuntimeError: If the joint graph has subclass tensor inputs/outputs; this + is not supported by API as there is not necessarily a 1-1 correspondence + between outputs and tangents when subclasses are involved. + """ + output_index: dict[DifferentiableAOTOutput, tuple[fx.Node, Optional[fx.Node]]] = {} + for n in g.nodes: + if n.op == "output": + desc = n.meta["desc"] + for sub_n, sub_d in zip(n.args[0], desc): + # Skip outputs that cannot possibly be differentiable + if not isinstance(sub_d, DifferentiableAOTOutput): + continue + if isinstance(sub_d, SubclassGetAttrAOTOutput): + _raise_autograd_subclass_not_implemented(sub_n, sub_d) + # pyrefly: ignore [unsupported-operation] + output_index[sub_d] = (sub_n, None) + for n in g.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_autograd_subclass_not_implemented(n, desc) + if isinstance(desc, TangentAOTInput): + out, tangent = output_index[desc.output] + assert tangent is None, (n, desc, output_index) + output_index[desc.output] = (out, n) + return output_index + + +def get_param_and_grad_nodes( + graph: fx.Graph, +) -> dict[ParamAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """Get parameter nodes and their corresponding gradient nodes from a joint graph. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each ParamAOTInput descriptor to a tuple containing: + - The parameter input node + - The gradient (output) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_input_and_grad_nodes(graph).items() + if isinstance(desc, ParamAOTInput) + } + + +def get_plain_input_and_grad_nodes( + graph: fx.Graph, +) -> dict[PlainAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """Get plain input nodes and their corresponding gradient nodes from a joint graph. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each PlainAOTInput descriptor to a tuple containing: + - The plain input node + - The gradient (output) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_input_and_grad_nodes(graph).items() + if isinstance(desc, PlainAOTInput) + } + + +def get_plain_output_and_tangent_nodes( + graph: fx.Graph, +) -> dict[PlainAOTOutput, tuple[fx.Node, Optional[fx.Node]]]: + """Get plain output nodes and their corresponding tangent nodes from a joint graph. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each PlainAOTOutput descriptor to a tuple containing: + - The plain output node + - The tangent (input) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_output_and_tangent_nodes(graph).items() + if isinstance(desc, PlainAOTOutput) + } + + +def _raise_fqn_subclass_not_implemented( + n: fx.Node, desc: Union[AOTInput, AOTOutput] +) -> NoReturn: + raise RuntimeError( + "Subclasses are currently not supported by this function, but a desugared subclass input " + f"was found at {n} ({desc}). The problem is " + "that there may not necessarily be a 1-1 correspondence between a FQN and a plain tensor " + "when subclasses are involved: for example, a parameter that is a subclass " + "would desugar into multiple plain tensors, which we can't uniquely assign the " + "FQN to. It's not clear what you want the API to do in this case: do you want to " + "instead return a struct of nodes showing how to assemble the subclass? But you " + "don't (directly) have the metadata for the subclass? If you have a concrete use " + "case, please file an issue so we can understand it and design an API that works for your case." + ) + + +def get_named_param_nodes(graph: fx.Graph) -> dict[str, fx.Node]: + """Get parameter nodes mapped by their fully qualified names. + + This function traverses the graph to find all parameter input nodes and + returns them in a dictionary where keys are the parameter names (FQNs) + and values are the corresponding FX nodes. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping parameter names (str) to their corresponding FX nodes. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + with subclasses a FQN does not necessarily map to a single plain tensor. + """ + r = {} + for n in graph.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_fqn_subclass_not_implemented(n, desc) + elif isinstance(desc, ParamAOTInput): + r[desc.target] = n + return r + + +def get_named_buffer_nodes(graph: fx.Graph) -> dict[str, fx.Node]: + """Get buffer nodes mapped by their fully qualified names. + + This function traverses the graph to find all buffer input nodes and + returns them in a dictionary where keys are the buffer names (FQNs) + and values are the corresponding FX nodes. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping buffer names (str) to their corresponding FX nodes. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + with subclasses a FQN does not necessarily map to a single plain tensor. + """ + r = {} + for n in graph.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_fqn_subclass_not_implemented(n, desc) + elif isinstance(desc, BufferAOTInput): + r[desc.target] = n + return r + + +def get_param_nodes(graph: fx.Graph) -> list[fx.Node]: + """Get all parameter nodes from a graph as a list. + + You can rely on this providing the correct order of parameters you need + to feed into the joint graph (at the very beginning of the argument list, + before buffers). + + Args: + graph: The FX joint graph with descriptors + + Returns: + A list of FX nodes representing all parameters in the graph. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + it is not clear if you wanted each individual constituent piece of the + subclasses, or have them grouped up in some way. + """ + return list(get_named_param_nodes(graph).values()) + + +def get_buffer_nodes(graph: fx.Graph) -> list[fx.Node]: + """Get all buffer nodes from a graph as a list. + + You can rely on this providing the correct order of buffers you need + to feed into the joint graph (after parameters). + + Args: + graph: The FX joint graph with descriptors + + Returns: + A list of FX nodes representing all buffers in the graph. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + it is not clear if you wanted each individual constituent piece of the + subclasses, or have them grouped up in some way. + """ + return list(get_named_buffer_nodes(graph).values()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/indexed_dict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/indexed_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..39a06996c6e08f1f3ac519e549f5012ffa8728eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/indexed_dict.py @@ -0,0 +1,54 @@ +from collections.abc import Iterator, MutableMapping +from typing import Generic, Optional, TypeVar + + +K = TypeVar("K") +V = TypeVar("V") + + +# Used for fast next key access (using the fact that the dict is ordered) +# Note: doesn't support deletion but we don't need it! +class IndexedDict(MutableMapping[K, V], Generic[K, V]): + """A dict that maintains insertion order with O(1) index access.""" + + __slots__ = ("_dict", "_keys", "_key_to_index") + + def __init__(self) -> None: + self._dict: dict[K, V] = {} + self._keys: list[K] = [] # typing: ignore[bad-override] + self._key_to_index: dict[K, int] = {} + + def __setitem__(self, key: K, value: V) -> None: + if key not in self._dict: + self._key_to_index[key] = len(self._keys) + self._keys.append(key) + self._dict[key] = value + + def __getitem__(self, key: K) -> V: + return self._dict[key] + + def __delitem__(self, key: K) -> None: + raise NotImplementedError("Deletion not supported for IndexedDict") + + def __len__(self) -> int: + return len(self._dict) + + def __iter__(self) -> Iterator[K]: + return iter(self._keys) + + def __contains__(self, key: object) -> bool: + return key in self._dict + + def next_key(self, key: K) -> Optional[K]: + """Get the next key in insertion order. O(1).""" + idx = self._key_to_index.get(key) + if idx is not None and idx + 1 < len(self._keys): + return self._keys[idx + 1] + return None + + def prev_key(self, key: K) -> Optional[K]: + """Get the previous key in insertion order. O(1).""" + idx = self._key_to_index.get(key) + if idx is not None and idx > 0: + return self._keys[idx - 1] + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..06581e1524fdef15475d9e9fc907b40ec858ad4b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -0,0 +1,466 @@ +# mypy: allow-untyped-defs +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the following analyses are provided: +1. Refine the view and mutation metadata collected previously - removing duplicate + inputs or mapping views to their bases. +2. We also analyze the function signature for export graphs. +""" + +import contextlib +import itertools +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C._dynamo.guards import compute_overlapping_tensors +from torch._functorch._aot_autograd.schemas import PlainTensorMeta +from torch._guards import StorageOverlap +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental.symbolic_shapes import is_concrete_int + +from .collect_metadata_analysis import coerce_tangent_and_suggest_memory_format +from .descriptors import AOTInput, InputMutationAOTOutput, TangentAOTInput +from .schemas import ( + BackwardSignature, + GraphSignature, + InputAliasInfo, + MemoryFormatMeta, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .utils import strict_zip + + +zip = strict_zip + + +def remove_dupe_metadata( + m: ViewAndMutationMeta, + keep_arg_mask: list[bool], + add_dupe_map: list[int], +) -> ViewAndMutationMeta: + assert len(m.input_info) == len(keep_arg_mask) + # Easy invariant: the first argument should never be a dupe (it will be kept) + assert len(keep_arg_mask) > 0 and keep_arg_mask[0] + + # Filter dupe'd mutated inputs out of traced_tangents + num_data_mutations = len([x for x in m.input_info if x.mutates_data]) + other_traced_tangents = m.traced_tangents[num_data_mutations:] + inp_traced_tangents = m.traced_tangents[:num_data_mutations] + other_traced_tangents_descs = m.traced_tangents_descs[num_data_mutations:] + inp_traced_tangents_descs = m.traced_tangents_descs[:num_data_mutations] + filtered_inp_traced_tangents = [ + # See Note [Tangents memory format] + x + for i, x in enumerate(inp_traced_tangents) + if keep_arg_mask[m.mutated_inp_runtime_indices[i]] + ] + filtered_inp_traced_tangents_descs = [ + x_desc + for i, x_desc in enumerate(inp_traced_tangents_descs) + if keep_arg_mask[m.mutated_inp_runtime_indices[i]] + ] + traced_tangents = filtered_inp_traced_tangents + other_traced_tangents + traced_tangents_descs = ( + filtered_inp_traced_tangents_descs + other_traced_tangents_descs + ) + + assert m.subclass_tangent_meta is not None + subclass_tangent_meta = [ + PlainTensorMeta( + 0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format) + ) + ] * len(filtered_inp_traced_tangents) + m.subclass_tangent_meta[num_data_mutations:] + + return ViewAndMutationMeta( + input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]], + # For outputs that are views of inputs, we store the index of the input that the output + # was generated from. Need to update that index to account for removed dupes. + output_info=[ + OutputAliasInfo( + output_type=o.output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], + requires_grad=o.requires_grad, + view_meta_sequence=o.view_meta_sequence, + ) + for o in m.output_info + ], + num_intermediate_bases=m.num_intermediate_bases, + keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + traced_tangents_descs=traced_tangents_descs, + # We are guaranteed not to get here, since dupes are not supported today with subclass inputs. + subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=subclass_tangent_meta, + is_train=m.is_train, + ) + + +# Given our ViewAndMutation metadata, this fn constructs a new set of metadata, +# after adding synthetic base arguments to the function. +# Most of the work in this fn is slogging through all of the metadata corresponding to inputs, +# and updating it with our synthetic base calling convention. +# +# When config.debug_assert is set, we automatically regenerate the metadata +# and compare it to this output for sanity. +# +# In addition to the updated metadata, also return the list of input indices +# that will need to be updated in the synthetic base epilogue +def create_synthetic_base_metadata( + m: ViewAndMutationMeta, + # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a + # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata) + synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]], + outer_args: list[Any], + inner_args: list[Any], + inner_args_desc: list[AOTInput], +) -> tuple[ViewAndMutationMeta, list[int]]: + # maps inner arg indices to outer arg indices + synthetic_base_to_indices: dict[int, list[int]] = {} + for inner_idx in range(len(inner_args)): + outer_aliased_indices_of_current_base_arg = [ + outer_idx + for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info) + if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx) + or ( + isinstance(inner_idx_or_tuple, tuple) + and inner_idx_or_tuple[0] == inner_idx + ) + ] + synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg + + # given the requires_grad info on mutated inputs, + # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases. + input_infos = [] + for outer_indices in synthetic_base_to_indices.values(): + # leaf-ness should be all-or-nothing for aliased tensor. + # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf) + any_leaf = any(m.input_info[x].is_leaf for x in outer_indices) + all_leaf = all(m.input_info[x].is_leaf for x in outer_indices) + assert any_leaf == all_leaf + + mutates_data = ( + True + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_data + ) + mutates_metadata = ( + False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_metadata + ) + requires_grad = any(m.input_info[x].requires_grad for x in outer_indices) + mutations_under_no_grad_or_inference_mode = all( + m.input_info[x].mutations_under_no_grad_or_inference_mode + for x in outer_indices + ) + + mutation_inductor_storage_resize = all( + m.input_info[x].mutation_inductor_storage_resize for x in outer_indices + ) + + inpt_info = InputAliasInfo( + # If len(outer_indices) > 1, then this input is a synthetic base. + # The invariant is that to the rest of aot autograd, synthetic bases only show up if + # one of their aliases gets a data mutation. And if any of their aliases get metadata + # mutations, they will be hidden from the rest of aot autograd. + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=all( + m.input_info[x].mutations_hidden_from_autograd for x in outer_indices + ), + mutates_storage_metadata=( + False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_storage_metadata + ), + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + mutation_inductor_storage_resize=mutation_inductor_storage_resize, + is_leaf=any_leaf, + requires_grad=requires_grad, + keep_input_mutations=m.keep_input_mutations, + ) + input_infos.append(inpt_info) + + # Find any inputs that fulfill the following criteria: + # (1) They are part of a synthetic base (because they alias another input, + # and at least one input experiences a data mutation) + # (2) They experience a metadata mutation + outer_aliased_arg_idx_with_metadata_mutations = [ + outer_idx + for outer_idx, inpt_info in enumerate(m.input_info) + if inpt_info.mutates_metadata + and not isinstance(synthetic_base_info[outer_idx], int) + ] + + # grab the original requires grad info on the outputs, except the ones from the mutated inputs + input_metadata_output_info = [ + OutputAliasInfo( + output_type=OutputType.alias_of_input, + raw_type=FunctionalTensor, + dynamic_dims={ + i + for i, s in enumerate(outer_args[outer_idx].shape) + if not is_concrete_int(s) + }, + base_idx=synthetic_base_info[outer_idx][0], # type: ignore[index] + requires_grad=outer_args[outer_idx].requires_grad, + ) + for outer_idx in outer_aliased_arg_idx_with_metadata_mutations + ] + existing_output_infos = [] + for o in m.output_info: + new_base_idx = ( + None + if o.base_idx is None + else ( + synthetic_base_info[o.base_idx] + if isinstance(synthetic_base_info[o.base_idx], int) + else synthetic_base_info[o.base_idx][0] # type: ignore[index] + ) + ) + # If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change + new_output_type = ( + OutputType.alias_of_input + if o.output_type == OutputType.is_input and o.base_idx != new_base_idx + else o.output_type + ) + existing_output_infos.append( + OutputAliasInfo( + output_type=new_output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases + base_idx=new_base_idx, # type: ignore[arg-type] + requires_grad=o.requires_grad, + view_meta_sequence=o.view_meta_sequence, + ) + ) + + inner_mutated_tangents_and_memory_formats = [ + # See Note [Tangents memory format] + ( + coerce_tangent_and_suggest_memory_format(x), + TangentAOTInput(InputMutationAOTOutput(x_desc)), + ) + for inner_idx, (x, x_desc) in enumerate(zip(inner_args, inner_args_desc)) + if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad + ] + inner_mutated_tangents = [ + x[0][0] for x in inner_mutated_tangents_and_memory_formats + ] + inner_mutated_tangents_descs = [ + x[1] for x in inner_mutated_tangents_and_memory_formats + ] + inner_mutated_tangents_memory_formats = [ + x[0][1] for x in inner_mutated_tangents_and_memory_formats + ] + + output_info = existing_output_infos + input_metadata_output_info + # Regenerate traced tangents to include mutated inputs including synthetic bases + traced_tangents = ( + inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :] + ) + traced_tangents_descs = ( + inner_mutated_tangents_descs + + m.traced_tangents_descs[len(inner_mutated_tangents) :] + ) + assert m.subclass_tangent_meta is not None + subclass_tangent_meta = [ + PlainTensorMeta(0, memory_format=x) + for x in inner_mutated_tangents_memory_formats + ] + m.subclass_tangent_meta[len(inner_mutated_tangents) :] + + return ( + ViewAndMutationMeta( + input_info=input_infos, + output_info=output_info, + num_intermediate_bases=m.num_intermediate_bases, + keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + traced_tangents_descs=traced_tangents_descs, + # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs. + subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=subclass_tangent_meta, + is_train=m.is_train, + ), + outer_aliased_arg_idx_with_metadata_mutations, + ) + + +def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices): + num_aliases = len(aliased_input_indices) + + shape_env = None + maybe_suppress_guards = contextlib.nullcontext + tracing_context = torch._guards.TracingContext.try_get() + + if tracing_context is not None: + assert tracing_context.fake_mode is not None + shape_env = tracing_context.fake_mode.shape_env + + # Check whether we can actually get the dynamo sources from within AOTAutograd. + if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None: + maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment] + + # Check whether there are any symbolic values being used. + # We do this for 2 reasons: + # 1. StorageOverlap guard is only issued whenever dynamic shapes is turned on + # 2. Triggers the fast-path for computing storage overlapping + symbolic = any( + isinstance(x, torch.SymInt) + for i in aliased_input_indices + for x in [ + *fwd_inputs[i].shape, + *fwd_inputs[i].stride(), + fwd_inputs[i].storage_offset(), + ] + ) + + if torch._inductor.config.is_fbcode(): + if symbolic and num_aliases > 400: + from torch._subclasses.fake_tensor import ( + UnsupportedMutationAliasingException, + ) + from torch._utils_internal import justknobs_check + + msg = f"Encountered {num_aliases} dynamic, aliased/mutated inputs, consider setting dynamic=False" + + if justknobs_check( + "pytorch/compiler:aliased_inputs_with_mutation_and_dyn_shapes_killswitch", + False, + ): + raise UnsupportedMutationAliasingException(msg) + + with maybe_suppress_guards(): + aliased_fwd_inputs = [fwd_inputs[i] for i in aliased_input_indices] + actual_aliased_indices = { + aliased_input_indices[i] + for i in compute_overlapping_tensors(aliased_fwd_inputs, symbolic=symbolic) + } + + # Add the StorageOverlap AOTAutograd guard only if we are actually keeping track of + # dynamo sources inside AOTAutograd. + if ( + tracing_context is not None + # Make sure dynamic shapes is currently being used. + and symbolic + # We check that we have more than 1 aliased tensor, which should be true at + # this point, anyway. + and num_aliases > 1 + and aot_config.aot_autograd_arg_pos_to_source + ): + no_overlap_indices = list(set(aliased_input_indices) - actual_aliased_indices) + + overlapping_sources = [ + aot_config.aot_autograd_arg_pos_to_source[i] for i in actual_aliased_indices + ] + non_overlapping_sources = [ + aot_config.aot_autograd_arg_pos_to_source[i] for i in no_overlap_indices + ] + + tracing_context.guards_context.aotautograd_guards.append( + StorageOverlap(overlapping_sources, non_overlapping_sources) + ) + + return actual_aliased_indices + + +def _graph_input_names(gm): + return [node.name for node in gm.graph.find_nodes(op="placeholder")] + + +def _graph_output_names(gm): + output_node = next(iter(reversed(gm.graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + return [getattr(return_arg, "name", None) for return_arg in return_args] + + +def create_graph_signature( + fx_g: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + *, + user_args_flat: list[Tensor], + params_and_buffers_flat: list[Tensor], + param_names: list[str], + buffer_names: list[str], + trace_joint: bool, + num_user_fw_outs: Optional[int], + loss_index: Optional[int], +) -> GraphSignature: + # Retrieve graph input names + graph_input_names = _graph_input_names(fx_g) + # Retrieve graph output names + graph_output_names = _graph_output_names(fx_g) + + num_params_buffers = len(param_names) + len(buffer_names) + num_tokens = len(fw_metadata.tokens) + # We have enough restrictions on the graph (no de-duping, synthetic bases, etc), + # Such that # graph inps = # user inps + # params + # buffers + num_user_args = len(graph_input_names) - num_params_buffers - num_tokens + + if trace_joint: + assert num_user_fw_outs is not None + num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices + backward_output_names = graph_output_names[num_fw_outs:] + + grad_index = itertools.count(0) + gradients_to_parameters = { + backward_output_names[next(grad_index)]: param_names[i] + for i, param in enumerate(params_and_buffers_flat) + if param.requires_grad + } + + gradients_to_user_inputs = { + backward_output_names[next(grad_index)]: graph_input_names[ + i + len(params_and_buffers_flat) + ] + for i, user_input in enumerate(user_args_flat) + if user_input.requires_grad + } + + assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len( + backward_output_names + ) + + # Check that we have fully accounted for all graph outputs + backward_signature = BackwardSignature( + gradients_to_parameters, + gradients_to_user_inputs, + graph_output_names[loss_index], + ) + else: + backward_signature = None + num_user_fw_outs = ( + len(graph_output_names) + - fw_metadata.num_mutated_inp_runtime_indices + - num_tokens + ) + + return GraphSignature.from_tracing_metadata( + in_spec=in_spec, + out_spec=out_spec, + graph_input_names=graph_input_names, + graph_output_names=graph_output_names, + view_mutation_metadata=fw_metadata, + named_parameters=param_names, + named_buffers=buffer_names, + num_user_inputs=num_user_args, + num_user_outputs=num_user_fw_outs, + trace_joint=trace_joint, + loss_index=loss_index, + backward_signature=backward_signature, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc03c7adb7ee1a3799b874f29f879d23055926d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py @@ -0,0 +1,1297 @@ +# mypy: allow-untyped-defs +""" +The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes +input/output types, metadata, config, function signatures etc. +""" + +from __future__ import annotations + +import collections +import functools +import itertools +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, NewType, Optional, Protocol, TYPE_CHECKING, TypeVar, Union + +import torch +import torch.utils._pytree as pytree +from torch import SymInt, Tensor +from torch._subclasses import FakeTensor +from torch._subclasses.fake_tensor import is_fake +from torch.fx.experimental._backward_state import BackwardState +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence +from .utils import strict_zip + + +if TYPE_CHECKING: + import contextlib + from collections.abc import Callable, Iterable, Sequence + + from torch._guards import Source + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + from torch._ops import OpOverload + + from .descriptors import AOTInput, AOTOutput + from .graph_capture_wrappers import JointFnHandle + + +zip = strict_zip + + +OutputType = Enum( + "OutputType", + ( + # output is not an alias + "non_alias", + # output aliases an input + "alias_of_input", + # output **is** an input tensor + "is_input", + # output has a ._base tensor, which is a graph intermediate. + # We need to return its ._base as a graph output, + # so its requires_grad info is populated correctly. + # Instructs the runtime code to regenerate the current output + # from a base tensor, graph_intermediates[base_idx] + "alias_of_intermediate_save_as_output", + # Same as above; but we don't need to explicitly add its ._base + # as a graph output, because it already **is** a graph output. + "alias_of_intermediate", + # Same as above; but the output's ._base is **already** a user output. + # Instructs the runtime code to regenerate the current output from + # a base tensor, user_outputs[base_idx] + "alias_of_intermediate_base_is_user_output", + # See Note [Intermediate Bases Optimization] + "unsafe_view_alias", + # output is an alias, but has a custom autograd.Function backward. + # In this case, we don't want to do view-replay, since we won't be able to replay the custom function. + # Instead, we'll treat this output "normally", and trace its backward into the graph. + "custom_function_view", + ), +) + + +# This class stores info about every user output. +@dataclass(frozen=True) +class OutputAliasInfo: + # Tells us if this output is: + # (1) a regular (non-aliased) output + # (2) an alias of a forward input + # (3) **is** a forward input (special case of "alias_of_input") + # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward) + # (5) an alias of an intermediate, that explicitly requires returning the intermediate + # as a graph output + # (6) an alias of an intermediate, where that intermediate is also a user output + output_type: OutputType + # The raw type of the output (torch.Tensor, SymInt, etc) + raw_type: type + # If (1) above, then + # - base_idx is None + # If (2) or (3) above, then + # - Tells us that the base of this alias is user_fwd_input[base_idx] + # (This is an index into the inputs *before* we make synthetic bases) + # If (4) or (5) above, then + # - Tells us that the base of this alias is output_graph_intermediates[base_idx] + # here, this refers to the index of the *direct* traced + # If (6) above, then: + # - Tells us that the base of this alias is output_user_fwds[base_idx] + # here, this refers to the index of the *direct* traced + base_idx: Optional[int] + # If it is a Tensor, what the dynamic dims are (otherwise is None) + dynamic_dims: Optional[set[int]] + # requires_grad + requires_grad: bool + # Sequence of ViewMeta objects. + # + # Provides us the means to re-run view functions on other tensors. + # + # We need to wrap the actual list of ViewMeta with this class so that + # we compare the ViewMeta elements appropriately, i.e. their type and + # the elements returned by the `as_tuple()` call. + view_meta_sequence: Optional[ViewMetaSequence] = None + + +class MutationType(Enum): + NOT_MUTATED = 1 + MUTATED_IN_GRAPH = 2 + MUTATED_OUT_GRAPH = 3 + + +# This class tells us info about user inputs. +@dataclass(frozen=True) +class InputAliasInfo: + is_leaf: bool + mutates_data: bool + mutates_metadata: bool + mutations_hidden_from_autograd: bool + mutations_under_no_grad_or_inference_mode: bool + mutation_inductor_storage_resize: bool + mutates_storage_metadata: bool + requires_grad: bool + keep_input_mutations: bool + + def __post_init__(self): + if self.mutates_storage_metadata: + # For convenience, we guarantee that this is always true. + # In practice, If we call .set_(), then at runtime there is no need + # to additionally fix up the tensor metadata, since our runtime + # call to inp.set_(updated_inp) will already have the right metadata + assert self.mutates_metadata + + @functools.cached_property + def mutation_type(self) -> MutationType: + if ( + (not self.mutates_data) + and (not self.mutates_metadata) + and not (self.mutation_inductor_storage_resize) + ): + return MutationType.NOT_MUTATED + + if _check_if_mutation_can_be_in_graph( + self.keep_input_mutations, + self.mutates_data, + self.mutates_metadata, + self.mutations_hidden_from_autograd, + self.mutations_under_no_grad_or_inference_mode, + self.mutates_storage_metadata, + self.mutation_inductor_storage_resize, + self.requires_grad, + ): + return MutationType.MUTATED_IN_GRAPH + + return MutationType.MUTATED_OUT_GRAPH + + +@dataclass +class MemoryFormatMeta: + # For static shapes we assume tangents have the same strideness as outputs + size: Optional[Sequence[int]] = None + stride: Optional[Sequence[int]] = None + + # For dynamic shapes we assume the same memory format: contiguous, channels_last etc. + memory_format: Optional[torch.memory_format] = None + + @staticmethod + def from_tensor(t: torch.Tensor) -> Optional[MemoryFormatMeta]: + # We only memorize expected memory format for + # 1. Traceable wrapper subclasses + # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. + # 2. Dynamic shape tensors + # Support for symbolic shapes is not implemented yet. + use_memory_format: bool = ( + not torch._functorch.config.guess_tangent_strides_as_outputs + or is_traceable_wrapper_subclass(t) + ) + if not use_memory_format: + is_static_shape = True + for s in itertools.chain(t.shape, t.stride()): + if not isinstance(s, int): + is_static_shape = False + break + + use_memory_format = not is_static_shape + + if use_memory_format: + return MemoryFormatMeta( + # pyrefly: ignore [unbound-name] + memory_format=torch._prims_common.suggest_memory_format(t), + ) + + return MemoryFormatMeta( + size=t.size(), + stride=t.stride(), + ) + + +@dataclass +class PlainTensorMeta: + unwrapped_idx: int + memory_format: Optional[MemoryFormatMeta] = None + + +@dataclass +class SubclassCreationMeta: + """ + Used for AOTDispatch. + This dataclass gives us the information we need to reconstruct a tensor subclass + from our flat inputs. + Why is this important? The graph that we'd like to trace out contains flat tensor inputs, + But the user's original model may have subclass inputs and outputs. + So we need to wrap/unwrap subclasses as necessary to translate between the user's + view (subclass inps/outs), and the backend compiler's view (graph with no subclass args). + + Complications arise mostly from the fact that a subclass can hold more than one inner tensor; + So for a given subclass input/output, we need to carefully track which indices map + to the subclass tensor in the corresponding "dense-tensor-only" graph. + """ + + # In the inner graph that only takes in dense tensor inputs, + # this maps to the first index of "tensors that should go in this subclass wrapper" + flat_tensor_start_idx: int + # arg_count is inclusive of the arg_counts of any + # inner tensor subclasses: If I have a TwoTensor and + # both of its inner elements are TwoTensors, then the + # arg_count of the outer-most subclass will be 4 + arg_count: int + # Mark where or not symints were included. This flag is only used in one assertion + # in "wrap_tensor_subclasses" + included_subclass_symints: bool + # meta and attrs are produced by the subclass's __tensor_flatten__. + # We need to keep them around along with outer_size / outer_stride to plumb them + # into __tensor_unflatten__ + attrs: dict[str, Union[SubclassCreationMeta, PlainTensorMeta]] + outer_size: Iterable[Union[None, int, torch.SymInt]] + outer_stride: Iterable[Union[None, int, torch.SymInt]] + meta: Any + # Stores the original subclass itself. + # This is needed because we need the autograd metadata on the original subclass + # (this is guaranteed to be a wrapper subclass that holds a fake tensor, + # so holding onto this at runtime shouldn't leak memory) + # This field is nulled out after calling make_runtime_safe() + original_subclass: Optional[torch.Tensor] + + # Used at runtime to determine the subclass type, so we don't need to save the original subclass + original_subclass_type: Optional[type] = None + memory_format: Optional[MemoryFormatMeta] = None + + def compute_outer_size_and_stride( + self, + all_args, + *, + curr_start_idx: int, + ): + from .subclass_utils import compute_symint_placeholders + + def compute(outer, start_idx): + placeholders = compute_symint_placeholders(outer) + has_symbolic = any(placeholders) + + if has_symbolic: + start = curr_start_idx + end = start_idx + sum(placeholders) + it_args = iter(all_args[start:end]) + it_placeholders = iter(placeholders) + return pytree.tree_map_only( + lambda _: next(it_placeholders), lambda _: next(it_args), outer + ), start + len(placeholders) + else: + return outer, start_idx + + outer_size, next_idx = compute(self.outer_size, curr_start_idx) + outer_stride, _ = compute(self.outer_stride, next_idx) + return outer_size, outer_stride + + def creation_fn( + self, + all_args, + *, + is_runtime: bool, + ): + inner_tensors = {} + + curr_start_idx = self.flat_tensor_start_idx + for attr, creation_meta in self.attrs.items(): + if isinstance(creation_meta, PlainTensorMeta): + subclass = all_args[curr_start_idx] + curr_start_idx += 1 + else: + subclass = creation_meta.creation_fn( + all_args, + is_runtime=is_runtime, + ) + curr_start_idx += creation_meta.arg_count + inner_tensors[attr] = subclass + + if is_runtime: + assert self.original_subclass_type is not None + original_subclass_type = self.original_subclass_type + else: + original_subclass_type = type(self.original_subclass) + + if is_runtime: + outer_size, outer_stride = self.compute_outer_size_and_stride( + all_args, + curr_start_idx=curr_start_idx, + ) + else: + outer_size, outer_stride = self.outer_size, self.outer_stride + + rebuilt = original_subclass_type.__tensor_unflatten__( # type: ignore[attr-defined] + inner_tensors, self.meta, outer_size, outer_stride + ) + + if not is_runtime: + # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper + # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass. + # We don't trace through the autograd engine at runtime though, so no need + # to compute this extra metadata then! + torch._mirror_autograd_meta_to(self.original_subclass, rebuilt) # type: ignore[attr-defined] + + return rebuilt + + def make_runtime_safe(self): + def _make_size_runtime_safe(x: Union[None, int, torch.SymInt]) -> Optional[int]: + dummy = -1 + if isinstance(x, torch.SymInt): + # Replace nested ints by a dummy value (-1) as NJT ignores + # the outer_size/outer_stride at runtime. + return dummy if x.node.is_nested_int() else None + return x + + assert self.original_subclass is not None + self.original_subclass_type = type(self.original_subclass) + self.original_subclass = None + + # Note: NJT outer_size in AOTDispatcher + # `_make_size_runtime_safe` replaces any nested int with a dummy value (-1) + # to prevent serializing a SymInt at runtime. Internally, nested tensor __tensor_unflatten__ + # is designed to safely ignore this dummy value. + # For more details, see: https://github.com/pytorch/pytorch/blob/5141ade8e30c64e873e14dcc8de233da45d15025/torch/nested/_internal/nested_tensor.py#L266-L299 # noqa: B950 + self.outer_size = tuple(map(_make_size_runtime_safe, self.outer_size)) + self.outer_stride = tuple(map(_make_size_runtime_safe, self.outer_stride)) + + # Recurse on nested subclass info + for creation_meta in self.attrs.values(): + if isinstance(creation_meta, SubclassCreationMeta): + creation_meta.make_runtime_safe() + + def __post_init__(self): + # sanity assert to make sure we don't leak memory + assert is_fake(self.original_subclass) + + +# This class encapsulates all aliasing + mutation info we need about the forward graph +# See a more detailed overview of the edge case handling at +# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit +# NOTE: This class is saved in AOTAutogradCache, If you are adding elements, make sure +# they are covered by warm cache tests. +@dataclass(eq=False) +class ViewAndMutationMeta: + # length = # user inputs + # This gives us info about every input, and what sort of mutation happened to it (if any) + input_info: list[InputAliasInfo] + + # length = # user outputs + # This gives us info about every output (mostly around whether it aliases other tensors) + output_info: list[OutputAliasInfo] + + # length = the number of intermediate bases appended as outputs to the end of the forward graph. + # Note: this is not necessarily the same thing as: + # len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate]) + # Because outputs might share a ._base, or an output's ._base might itself be + # another user output (in both cases, we won't redundantly append bases to the end of the graph) + num_intermediate_bases: int + + # For inference only: instructs us to keep data-only input mutations directly in the graph + keep_input_mutations: bool + + # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) + # + (# intermediate bases) + # These are the FakeTensor (or potential SymInt) outputs that we traced from our + # metadata pass of the user's forward function. + # Their only use today is to pass them as a best-guess for tangents when tracing the joint. + # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis + # pass once, and reuse the output throughout AOTAutograd + traced_tangents: list[Any] + + # TODO doc + traced_tangents_descs: list[AOTInput] + + # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs + # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors, + # Given a (potentially larger) list of plain torch tensors. + + # Taking subclass_inp_meta as an example: + # subclass_inp_meta[i] = j (an int) tells us: + # "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph." + # subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2) + # "The i'th user input is subclass holding two inner tensors, which are + # inputs[3] and inputs[4] of the plain-tensor graph". + + # length = # user inputs + subclass_inp_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # So, the full set of outputs to the forward graph looks something like: + # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors) + # where the first 3 of those 4 can be subclasses + # (but not saved_for_bw tensors, since these are internal to the compiler + # and not user visible, so there's no point in wrapping/unwrapping them at runtime). + # This list contains subclass information on all of the fw graph outputs + # except for saved_for_bw_tensors. + subclass_fw_graph_out_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # length = # backward graph inputs + subclass_tangent_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # TODO: we should kill this + # (need to default it to not break internal) + is_train: bool = False + + # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) + # + (# intermediate bases) + # At runtime, we don't keep the traced_tangents around since they're not serializable. + # Instead, we keep any necessary subclass metadata necessary about each traced_tangent. + # This list is generated after calling make_runtime_safe(). + traced_tangent_metas: Optional[list[Any]] = None + + num_symints_saved_for_bw: Optional[int] = None + + # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue + # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode + # that is intended to be in effect prior to running the graph, in keeping with + # equivalence to eager mode. It is the responsibility of upstream graph acquisition + # to reset the grad mode to its pre-graph value prior to calling aot_autograd. + grad_enabled_mutation: Optional[bool] = None + + # Keeps track of whether `torch.use_deterministic_algorithms` was turned on + # when the forward was run. If deterministic mode was turned off during the + # forward, but is turned on during the backward call, then an error is + # raised + deterministic: Optional[bool] = None + + # Keeps track of which input indices store parameters (which we will treat as static) + static_input_indices: list[int] = field(default_factory=list) + + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are + # side-effectful operators, FunctionalTensorMode will populate this + # dictionary telling us how many tokens we will need during tracing. + tokens: dict[Any, torch.Tensor] = field(default_factory=dict) + + # Only filled in if/when we trace the joint function + # If an input requires grad and is mutated in the backward, it is only safe to keep the mutation + # in the graph if gradients are disabled while the backward runs + # (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True) + # At runtime during the backward, we use this list of indices to error properly if we find out + # that it was not safe to include a backward mutation in the graph. + indices_of_inputs_that_requires_grad_with_mutations_in_bw: list[int] = field( + default_factory=list + ) + + # Indexes of saved tensors which are donated buffer. + # Donated buffer means the tensor is not alias of any forward user input, forward user output, + # and backward output. + bw_donated_idxs: Optional[list[int]] = None + + # Number of tokens used in backward, appended at the end of backward outputs. + # Filled after tracing joint function. + num_backward_tokens: int = 0 + + # Number of rng states that will get thread into the forward and backward for + # cudagraph compatible run_and_save_rng + num_graphsafe_rng_states: int = 0 + + graphsafe_rng_state_index: Optional[int] = None + + def __post_init__(self): + # pre-compute the indices of the inputs that are mutated. + # When keep_input_mutations is set, we don't need to worry about our epilogue + # handling data-only mutations, because we keep them directly in the graph. + mutated_inp_runtime_indices = [ + i + for i, m in enumerate(self.input_info) + if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH) + ] + + mutated_graph_handled_indices = [ + i + for i, m in enumerate(self.input_info) + if m.mutation_type == MutationType.MUTATED_IN_GRAPH + ] + self.mutated_graph_handled_indices = mutated_graph_handled_indices + self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices) + + mutated_graph_handled_indices_seen_by_autograd = [ + i + for i in mutated_graph_handled_indices + if not self.input_info[i].mutations_hidden_from_autograd + ] + + self.mutated_graph_handled_indices_seen_by_autograd = ( + mutated_graph_handled_indices_seen_by_autograd + ) + self.num_mutated_graph_handled_indices_seen_by_autograd = len( + self.mutated_graph_handled_indices_seen_by_autograd + ) + + aliased_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type + not in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + unsafe_view_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type is OutputType.unsafe_view_alias + ] + + # This is pre-computed in post_init for perf. + # It contains the index of every element + # of input_info that corresponds to a mutation (data or metadata or both) + self.mutated_inp_runtime_indices = mutated_inp_runtime_indices + self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices) + + # This is pre-computed for perf. + # It contains the index of every element + # of output_info that corresponds to an alias (either of an input or intermediate) + self.aliased_out_indices = aliased_out_indices + self.unsafe_view_out_indices = unsafe_view_out_indices + self.num_outputs = len(self.output_info) + self.num_outputs_non_aliased = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + ) + self.num_outputs_aliased_to_inputs = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_input, + OutputType.is_input, + ] + ] + ) + self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices) + self.num_outputs_aliased_to_intermediates = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output, + ] + ] + ) + self.num_outputs_aliased = ( + self.num_outputs_aliased_to_inputs + + self.num_outputs_aliased_to_intermediates + ) + + # Record dynamic outputs of the Dynamo traced forward graph + # Mark them as dynamic at the end of the runtime wrapper + self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info) + + # Record the indices of dynamic outputs in the partitioned forward graph + # Mark them as dynamic in the runtime wrapper + # activation index -> dynamic dims indices + self.dynamic_saved_tensors_idxs: dict[int, set[int]] = {} + + # See Note: [AOTAutograd Backward Guards] + # This is pre-computed for fast asserts on the types of our grad_outputs in the backward. + # Eventually, we should kill this and replace with real backward guards. + # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor) + self.output_types = [ + torch.Tensor if isinstance(x, FakeTensor) else type(x) + for x in self.traced_tangents + ] + + self.is_rng_op_functionalized = config.functionalize_rng_ops + # All of the above metadata is collected by tracing the fw function. + # However, extra outputs for rng offsets behave differently. Both fwd + # and bwd graphs have their own outputs for the total consumed offsets. + # Unlike mutated inputs, we don't have to worry about sending the right + # set of tensors between fwd and bwd. Fwd and bwd offsets are + # independent and simpler to handle. Therefore, we track them + # separately. + self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0 + + # Our forward() returns both (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints) + # Tokens will be split out before mutations/view handling and we do not count them here. + self.num_forward_returns = ( + self.num_mutated_inp_runtime_indices + + self.num_outputs + + self.num_intermediate_bases + ) + # In case of functionalization of rng ops, the fw_module returns one + # additional output for rng offset. This rng offset is used right + # away to advance the rng state, and is not passed on to the raw + # outputs. However, we need to know the exact boundary to identify + # which tensors to be saved for the bwd graph. num_forward captures + # this information. + self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset + + def make_runtime_safe(self): + """ + There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing + is completed to simplify certain fields in the metadata so that they can be safely cached. + + Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime. + """ + # TODO: This function is only a best effort: there are other fields that may not be cache safe + # (i.e., there's no guarantee that tensor_flatten() returns a serializable result), or that + # SubclassCreationMeta is cache safe. + assert self.traced_tangent_metas is None + + def extract_metadata(t): + if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t): + (inner_tensors, flatten_spec) = t.__tensor_flatten__() # type: ignore[attr-defined] + # Technically, we only need the flatten_spec, not the inner tensors. + # However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None. + # And we want to be able to assert that this metadata is non-None, + # to distinguish between "this was a tensor subclass with no metadata" vs. + # "this wasn't a tensor subclass at all". + return (inner_tensors, flatten_spec) + else: + return None + + self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] + # Clear traced tangents at runtime + self.traced_tangents = [] + for inp_meta in self.subclass_inp_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + for inp_meta in self.subclass_fw_graph_out_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + for inp_meta in self.subclass_tangent_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + + @property + def tensors_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(self.num_forward, -self.num_symints_saved_for_bw) + else: + return slice(self.num_forward, None) + + @property + def symints_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(-self.num_symints_saved_for_bw, None) + else: + return slice(0, 0) # empty slice + + def __eq__(self, other): + if not isinstance(other, ViewAndMutationMeta): + return NotImplemented + return ( + self.input_info == other.input_info + and self.output_info == other.output_info + and self.num_intermediate_bases == other.num_intermediate_bases + and self.keep_input_mutations == other.keep_input_mutations + and self.is_rng_op_functionalized == other.is_rng_op_functionalized + and self.num_outputs_rng_offset == other.num_outputs_rng_offset + and len(self.traced_tangents) == len(other.traced_tangents) + and all( + x.shape == y.shape and x.dtype == y.dtype + for x, y in zip(self.traced_tangents, other.traced_tangents) + ) + and self.num_backward_tokens == other.num_backward_tokens + ) + + +@dataclass(eq=False) +class SubclassMeta: + # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses) + # So for example, if the user had a model containing two `TwoTensor` inputs, + # Then `SubclassMeta.fw_metadata.input_infos` would have length 4 here. + fw_metadata: ViewAndMutationMeta + + # Note: [Computing Subclass Metadata about grad_inputs] + # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses + # + # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs? + # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous) + # + # This doesn't really work though. take this example: + # + # def f(DoubleTensor, DenseTensor): + # return DoubleTensor * DenseTensor + # + # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor. + # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs. + # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input) + # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors. + # + # Note that this info **cannot** easily be figured out from ViewAndMutationMeta. + # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed. + # + # See Note: [AOTAutograd Backward Guards] + # This will also eventually require us to install backward guards, + # in case we made incorrect assumptions about the subclass-ness of our grad_outputs + # + # Optional field because we don't compute for inference graphs + grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = ( + None + ) + + def __init__(self) -> None: + # The fields in this class get set after its construction. + pass + + +# This class exists because: +# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs +# - we only care about the metadata on those aliases, so we can regenerate them. +# We do not want them to participate in the autograd.Function. +# We do that by wrapping them in an opaque class, so the autograd.Function +# does not know to treat them as tensors. +@dataclass(frozen=True) +class TensorAlias: + alias: torch.Tensor + + +@dataclass +class BackwardSignature: + """ + Provides information about the backward section of an exported + joint forward-backward graph. + For a particular fx GraphModule, this class contains information on: + (1) A mapping from each gradient (backwards output) to the parameter + it corresponds to (forward input) + (2) A mapping from each gradient (backwards output) to the user input + it corresponds to (forward input) + (3) Which of the forward outputs corresponds to the loss, that we backprop on. + + Each string name is the `node.name` of the corresponding node in the fx graph. + """ + + gradients_to_parameters: dict[str, str] + gradients_to_user_inputs: dict[str, str] + loss_output: str + + +GraphOutputName = NewType("GraphOutputName", str) +GraphInputName = NewType("GraphInputName", str) +FQN = NewType("FQN", str) + + +@dataclass +class GraphSignature: + """ + Provides information about an exported module. + For a particular fx GraphModule, this class contains information on: + (1) Which graph inputs are parameters, buffers, or user inputs + (2) (for params/buffers) a mapping from the name of each graph argument + to its parameter/buffer FQN in the original nn.Module. + (3) If there are input mutations, these are represented as extra outputs + in the fx GraphModule. We provide a mapping from these + extra output names to the names of the actual inputs. + (4) The pytree metadata on how to flatten/unflatten inputs and outputs. + The corresponding FX GraphModule only accepts and returns + pytree-flattened inputs/outputs. + (5) (Optionally) if the FX is a joint forward-backward graph, we provide + a signature on the backward section of the joint graph. + """ + + parameters: list[FQN] + buffers: list[FQN] + + user_inputs: list[GraphInputName] + user_outputs: list[GraphOutputName] + inputs_to_parameters: dict[GraphInputName, FQN] + inputs_to_buffers: dict[GraphInputName, FQN] + + # If the user's module mutates a buffer, + # it's represented in the graph as an extra graph output. + # This dict is a mapping from + # "graph outputs that correspond to updated buffers" + # to the FQN names of those mutated buffers. + buffers_to_mutate: dict[GraphOutputName, FQN] + parameters_to_mutate: dict[GraphOutputName, FQN] + user_inputs_to_mutate: dict[GraphOutputName, GraphInputName] + + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + backward_signature: Optional[BackwardSignature] + + input_tokens: list[GraphInputName] + output_tokens: list[GraphOutputName] + + @classmethod + def from_tracing_metadata( + cls, + *, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + graph_input_names: list[str], + graph_output_names: list[str], + view_mutation_metadata: ViewAndMutationMeta, + named_parameters: list[str], + named_buffers: list[str], + num_user_inputs: int, + num_user_outputs: int, + trace_joint: bool, + loss_index: Optional[int], + backward_signature: Optional[BackwardSignature], + ) -> GraphSignature: + graph_inputs = graph_input_names + graph_outputs = graph_output_names + parameters = list(named_parameters) + buffers = list(named_buffers) + num_tokens = len(view_mutation_metadata.tokens) + + # Calling convention assumptions: + # (1) graph inputs = (input_tokens, params, buffers, user_inputs) + # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients) + # (If we are capturing an inference graph, this convention is identical + # except that param_gradients is empty) + # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens + + # Address input calling conventions: + start, stop = 0, num_tokens + input_tokens = graph_inputs[start:stop] + + start, stop = stop, stop + len(parameters) + inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters)) + + start, stop = stop, stop + len(buffers) + inputs_to_buffers = dict( + zip( + graph_inputs[start:stop], + buffers, + ) + ) + + start, stop = stop, stop + num_user_inputs + user_inputs = graph_inputs[start:stop] + + # We should've gone through all the inputs now + assert len(graph_inputs) - stop == 0 + + # Address output calling conventions: + start, stop = 0, num_tokens + output_tokens = graph_outputs[start:stop] + + names = [*input_tokens, *parameters, *buffers, *user_inputs] + mutations = [] + for idx, input_info in enumerate(view_mutation_metadata.input_info): + if input_info.mutates_data: + if trace_joint: + # Only buffers can be mutated, not parameters + assert idx >= len(parameters) + mutations.append(names[idx + num_tokens]) + + assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices + + start, stop = ( + stop, + stop + view_mutation_metadata.num_mutated_inp_runtime_indices, + ) + outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations)) + + user_inputs_to_mutate = {} + buffers_to_mutate = {} + parameters_to_mutate = {} + for output_name, mutation_name in outputs_to_mutations.items(): + if mutation_name in user_inputs: + # pyrefly: ignore [unsupported-operation] + user_inputs_to_mutate[output_name] = mutation_name + else: + assert mutation_name in buffers or mutation_name in parameters + if mutation_name in buffers: + # pyrefly: ignore [unsupported-operation] + buffers_to_mutate[output_name] = mutation_name + else: + # pyrefly: ignore [unsupported-operation] + parameters_to_mutate[output_name] = mutation_name + + start, stop = stop, stop + num_user_outputs + user_outputs = graph_outputs[start:stop] + + unused_outputs = len(graph_outputs) - stop + if backward_signature is not None: + unused_outputs -= len(backward_signature.gradients_to_parameters) + len( + backward_signature.gradients_to_user_inputs + ) + assert unused_outputs == 0 + + return GraphSignature( + parameters=parameters, # type: ignore[arg-type] + buffers=buffers, # type: ignore[arg-type] + user_inputs=user_inputs, # type: ignore[arg-type] + user_outputs=user_outputs, # type: ignore[arg-type] + inputs_to_buffers=inputs_to_buffers, # type: ignore[arg-type] + inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type] + user_inputs_to_mutate=user_inputs_to_mutate, + buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type] + parameters_to_mutate=parameters_to_mutate, # type: ignore[arg-type] + in_spec=in_spec, + out_spec=out_spec, + backward_signature=backward_signature, + input_tokens=input_tokens, # type: ignore[arg-type] + output_tokens=output_tokens, # type: ignore[arg-type] + ) + + +@dataclass +class AOTAutogradCacheInfo: + cache_key: str + start_time_ns: int + forward_symints: list[torch.SymInt] + + +@dataclass +class AOTConfig: + """ + Configuration for AOTDispatcher + """ + + fw_compiler: Callable + bw_compiler: Callable + partition_fn: Callable + decompositions: dict[OpOverload, Callable] + num_params_buffers: int + aot_id: int + keep_inference_input_mutations: bool + is_export: bool = False + no_tangents: bool = False + dynamic_shapes: bool = False + aot_autograd_arg_pos_to_source: Optional[list[Source]] = None + static_input_indices: Optional[list[int]] = None + inference_compiler: Optional[Callable] = None + enable_log: bool = True + # this is always false outside of export. + pre_dispatch: bool = False + # Key to use for AOTAutogradCache + cache_info: Optional[AOTAutogradCacheInfo] = None + # If we should ignore the shape_env in the ambient tracing_context. + # The net effect is that if dynamic shapes are on, we end up + # specializing on example_inputs. + # Used only by standalone_compile. + ignore_shape_env: bool = False + precompile_backend_id: Optional[str] = None + force_non_lazy_backward_lowering: bool = False + # This config makes sure to check certain things like + # mutating input with req_grad in export joint tracing. + export_trace_joint: bool = False + disable_functionalization: bool = False + + def __post_init__(self): + if self.pre_dispatch: + assert self.is_export, "Can only have pre_dispatch IR for export." + + +# TODO: types here +# plain_tensor_trace_fn, when it is joint, has tuple structure on the trace +# info too! +# TODO: this needs to be generic, parameterized on AOTDescriptor +SubclassTracingInfo = collections.namedtuple( + "SubclassTracingInfo", + [ + "plain_tensor_trace_fn", + "plain_tensor_args", + "plain_tensor_args_descs", + "maybe_subclass_meta", + ], +) + + +@dataclass +class AOTState: + """ + When we run AOTAutograd, this class encapsulates the state in the compiler which + must be preserved across stages. This is state in the traditional sense (not an + environment) because some values in this structure change as we progress through + pipelines in AOTAutograd. + """ + + # Whether or not we need to handle autograd when doing graph capture and + # compilation. Although the calling convention for non-autograd graph + # capture in AOTAutograd is simple and can be relied upon, the autograph + # capture calling convention is quite complicated and in general you are + # only expected to pass to aot_stage2_compile to process. + needs_autograd: bool + + # The FAKE flat arguments which we will do tracing with. Although you + # might naively expect this to be immutable, it's not: when we perform + # tracing, we may execute code that modifies the metadata of inputs, + # causing the args to become "invalid". It's also nontrivial to have a + # "golden" set of fake values and deepcopy them just in time when you + # might destructively mutate them (Voz and I tried very hard to do this). + # So we just periodically renew this field. Don't worry too much about + # this unless you're specifically trying to track down an input metadata + # mutation bug. + # + # (By the way, this is NEVER the joint inputs! Those only ever go in + # AOTGraphCapture) + flat_args: list[FxValue] + + # The descriptor for each argument in flat_args. + flat_args_descs: list[AOTInput] + + # This contains view and mutation information about the function, which we + # detected by doing an initial trace when we created this state. + fw_metadata: ViewAndMutationMeta + + # Top-level configuration + # This is morally immutable but sometimes we are naughty and mutate it. + aot_config: AOTConfig + + # When performing AOTAutograd traces and other passes, we typically + # require a lot of active context managers; most typically these either + # (1) ensure we are faithfully replicating the original PyTorch context + # managers or (2) toggle some behaviors in PyTorch to make it more + # suitable for tracing. When you use AOTState, you're expected to have + # created an ExitStack, entered it; then while we are running AOTAutograd + # we will add things onto the stack as necessary. When you're all done + # with processing AOTAutograd, you can exit this stack. All functions + # that take AOTState expect the ExitStack to not have been exited yet. + # + # TODO: We potentially could offer a resumable context manager, where you + # can cancel it and reenable it later when you need it. + stack: contextlib.ExitStack + + +FxValue = Union[Tensor, int, SymInt, BackwardState] + + +class CompilerWrapper: + """ + AOTAutograd needs to do many transformations to the calling convention of the user function + it is tracing, e.g., deduplicating inputs, unpacking subclasses, etc. CompilerWrapper lets + us factor these into compositional stages so we can handle each transformation incrementally + instead of having to do it all at once. + + Since there is a calling convention change, there are two parts to the wrpaper: + + 1. The prologue, which is about compile-time behavior: given this original function, what + is the new function with modified calling convention that we should trace with AOTAutograd + to get the FX graph we will do joint passes, partitioning and ultimate Inductor compilation on? + We get (flat_fn, flat_args), the original function under trace and inputs we were + going to feed it, and produce a new function and new inputs to feed it. + + 2. The epilogue, which is about run-time behavior: we have now compiled the modified calling + convention function, we need to wrap it so that we have a new function that has the + original calling convention of the original function, so that our users can call it + at the old signature they expected. We get (compiled_fn, real arguments), the newly + compiled function we need to wrap. + + Note about caching: we do NOT directly serialize the runtime wrappers; instead, they + are reapplied to compiled_fn after we have finished deserializing the compiled_fn. + + Extra metadata that is needed to compute pre or post compile can be passed in via attributes. + """ + + def pre_compile( + self, + flat_fn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[FxValue], list[AOTInput], ViewAndMutationMeta]: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return flat_fn, flat_args, flat_args_descs, fw_metadata + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +class InductorWrapper: + """ + This is sort of like CompilerWrapper, but it happens at a different part of the lifecycle: + it talks about transformations we do to the traced and partitioned FX graph before we + send it to the Inductor compiler. + + Once again, there are two parts: + + 1. The prologue, which "modifies" the FX graph before we send it to + Inductor. I say "modifies" because... we don't really actually do + anything nontrivial in either of our two implementations. + 2. The epilogue, which modifies the compiled function produced by Inductor + + Although hypothetically these wrappers could be used compositionally in a centralized + wrappers list, in practice they seem to just be invoked manually when needed. + + NB: The flat_args input is sometimes mutated. This is probably naughty but whatever. + """ + + def pre_compile( + self, + fw_module: torch.fx.GraphModule, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> None: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +@dataclass +class AOTGraphCapture: # Produced by aot_stage1_graph_capture + # AOTAutograd typically operates by taking complicated graphs and + # desugaring them into simpler graphs that use PyTorch features. These + # wrappers establish invariants so that when we actually do tracing we can + # assume these invariants hold, leading to a simpler tracing + # implementation. However, this means that we have to keep track of how + # to enter/exit these wrappers when passing inputs into the compiled + # graph, among other things! + wrappers: list[CompilerWrapper] + + # The actual captured graph module. In some circumstances (export) this + # graph has a specific calling convention that can be relied upon by + # external callers. In other situations, the calling convention is + # unspecified and only aot_stage2_compile knows how to deal with them. + graph_module: torch.fx.GraphModule + + # When compiling with autograd support, this is the joint_inputs, which is + # larger than the original flat_args as all tangents get inputs. The + # tuple organizes into primals and tangents. When not autograd it's just + # a plain list. + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + + updated_flat_args_descs: Union[ + list[AOTInput], tuple[list[AOTInput], list[AOTInput]] + ] + + # Metadata about subclass inputs/outputs in the graph trace. + maybe_subclass_meta: Any + + +FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) + + +TOutputCode = TypeVar("TOutputCode", bound="OutputCode") + + +class AOTDispatchCompiler(Protocol): + """ + Represents a fw or bw_compiler passed to AOTAutograd. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> Any: ... + + +# TODO: bikeshed on this name +class SerializableAOTDispatchCompiler(AOTDispatchCompiler): + """ + Represents an AOTDispatchCompiler that returns an OutputCode, and is + therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. + A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of + the kwargs in _CompileFxKwargs. + """ + + def __init__( + self, + output_code_ty: type[TOutputCode], + compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], + ): + # pyrefly: ignore [invalid-type-var] + self.output_code_ty = output_code_ty + # pyrefly: ignore [invalid-type-var] + self.compiler_fn = compiler_fn + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> OutputCode: + return self.compiler_fn(gm, example_inputs) + + +class FlatFn(Protocol): + def __call__(self, *args: FxValue) -> list[FxValue]: ... + + +class TraceFn(Protocol): + def __call__(self, *args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: ... + + +class PreppedForAutogradTraceFn(Protocol): + def __call__( + self, + *args: FxValue, + ) -> tuple[tuple[list[FxValue], list[bool]], list[AOTOutput]]: ... + + +class JointTraceFn(Protocol): + handle: JointFnHandle + + def __call__( + self, primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[Optional[Tensor]]], + tuple[list[AOTOutput], list[Optional[AOTOutput]]], + ]: ... + + +@dataclass +class JointWithDescriptors: + _aot_state: AOTState + _aot_graph_capture: AOTGraphCapture + + # The exact order parameters and buffers are expected to be passed into + # the final compiled function. Parameters before buffers. + params_spec: list[str] + buffers_spec: list[str] + + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + @property + def graph_module(self): + return self._aot_graph_capture.graph_module + + @graph_module.setter + def graph_module(self, value): + self._aot_graph_capture.graph_module = value diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea6635a62e81a57fba45e97d5f0eb2109e48d8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py @@ -0,0 +1,104 @@ +import dataclasses +import itertools +from collections.abc import Iterable +from typing import Any, Union + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +# This is technically very similar to SubclassCreatingMeta +# in aot_autograd, but we don't need all the stuff in there +# so just recreated a new dataclass. +@dataclasses.dataclass +class SubclassCreationMeta: + start_idx: int + num_tensors: int + class_type: Any + attrs: dict[str, "SubclassCreationMeta"] + metadata: Any + outer_size: Iterable[Union[None, int, torch.SymInt]] + outer_stride: Iterable[Union[None, int, torch.SymInt]] + + +class UnwrapTensorSubclass(torch.nn.Module): + def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def] + todo: list[torch.Tensor] = list(tensors) + + def _unwrap_tensor_subclasses(subclass_meta, tensors, offset): # type: ignore[no-untyped-def] + if subclass_meta is None: + return tensors[offset], offset + 1 + inner_tensors = {} + for attr, meta in subclass_meta.attrs.items(): + built_tensor, offset = _unwrap_tensor_subclasses(meta, tensors, offset) + inner_tensors[attr] = built_tensor + rebuilt = subclass_meta.class_type.__tensor_unflatten__( + inner_tensors, + subclass_meta.metadata, + subclass_meta.outer_size, + subclass_meta.outer_stride, + ) + return rebuilt, offset + + return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0] + + def right_inverse(self, tensor: torch.Tensor) -> list[torch.Tensor]: + assert type(tensor) is not torch.Tensor + plain_tensors: list[torch.Tensor] = [] + + def _create_subclass_meta(tensor, idx, plain_tensor_container): # type: ignore[no-untyped-def] + if type(tensor) is torch.Tensor: + plain_tensor_container.append(tensor) + return None, idx + 1 + inner_tensors_attrnames, metadata = tensor.__tensor_flatten__() # type: ignore[attr-defined] + new_idx = idx + attr_to_meta = {} + for attr in inner_tensors_attrnames: + val = getattr(tensor, attr) + subclass_meta, new_idx = _create_subclass_meta( + val, new_idx, plain_tensor_container + ) + attr_to_meta[attr] = subclass_meta + return ( + SubclassCreationMeta( + start_idx=idx, + num_tensors=new_idx - idx, + class_type=type(tensor), + attrs=attr_to_meta, + metadata=metadata, + outer_size=tensor.size(), + outer_stride=tensor.stride(), + ), + new_idx, + ) + + self.subclass_meta = _create_subclass_meta(tensor, 0, plain_tensors)[0] + return plain_tensors + + +def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Module: + """ + Model transformation that replaces all the parameters that are subclasses to plain tensors. + This reduces runtime overhead of flattening/unflattening the parameters. + + This transformation adds parametrization with `torch.nn.utils.parametrize`. + The FQNs of the subclass parameters will be changed and state_dict will become incompatible with the original model. + E.g. + Original model state_dict: {"p1": torch.testing._internal.TwoTensor} + becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor} + + """ + for name, tensor in itertools.chain( + list(module.named_parameters(recurse=False)), + # pyrefly: ignore [no-matching-overload] + list(module.named_buffers(recurse=False)), + ): + if is_traceable_wrapper_subclass(tensor): + torch.nn.utils.parametrize.register_parametrization( + module, name, UnwrapTensorSubclass() + ) + + for child in module.children(): + unwrap_tensor_subclass_parameters(child) + + return module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4a9c7e2af61c9d8e0dd59fafd096733b8749c68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23d26e30c24568510e21ea9886e8b50b24b64b5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb5484982ca6338b68c1d6cc91b65dd098232081 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6afa9b103d4fc887c1de6a2fed419734ff048ee0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e7937200497524b60763598daca3ebcfa013aba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88b88b1db760de19444c77c916e73b405a77c286 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e397221b4d111327ced070c763b0e9c0860ae306 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90e54af24d4242acfd6c73ecbecaee5729e6c34c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df48fe68d67c7c6d82b1404bd9f8d10d71d933ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d81ad0bd3c4e68287032a69d9528132ebb184682 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..827eb34388afbeaee869ef48ca448ff0e1a2ceaa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37d9ecec87c9a3fcff6cf65f5ead29a1624e760a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf508405b3d7b72ac27c489287e009f36961fec6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03efcfacae8fcf39050e65c9ac4a578f52f834f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c6ec63f6719ab35f168e589379f09fed91cd1e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403ddf43f122db40a32a4a4db9c9df54404c38e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0653d6130d98329827c417f52ccd799391837508 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc58b5954d6548a327907f7e09bb63874e5700e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8907e57b9540708fd89d3ccf7e2b9da3ccec14c4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e12ac899257a85192ed3412121c31c378495f0fa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a2525a9789a534119f98778e40154b20a50cc3c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c3091ba6cfa5698b5545fe272361434834560d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05e73b12e29f8e6608647a3f16fabab39fbfb582 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py @@ -0,0 +1,20 @@ +# mypy: ignore-errors + +from .utils import ( + _gen_alignment_data, + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, + assert_equal, + assert_raises_regex, + assert_warns, + HAS_REFCOUNT, + IS_WASM, + suppress_warnings, +) + + +# from .testing import assert_allclose # FIXME diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ab46453c888bf69d195239c9f426e139ecf0e33 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b05b657680d1c8c14b018118a332d429a86a3c47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc027043b6f55aae572e2fb0ffe1142f6226959 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_numpy/testing/utils.py @@ -0,0 +1,2451 @@ +# mypy: ignore-errors + +""" +Utility function to facilitate testing. + +""" + +import contextlib +import gc +import operator +import os +import platform +import pprint +import re +import shutil +import sys +import warnings +from functools import wraps +from io import StringIO +from tempfile import mkdtemp, mkstemp +from warnings import WarningMessage + +import torch._numpy as np +from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray + + +__all__ = [ + "assert_equal", + "assert_almost_equal", + "assert_approx_equal", + "assert_array_equal", + "assert_array_less", + "assert_string_equal", + "assert_", + "assert_array_almost_equal", + "build_err_msg", + "decorate_methods", + "print_assert_equal", + "verbose", + "assert_", + "assert_array_almost_equal_nulp", + "assert_raises_regex", + "assert_array_max_ulp", + "assert_warns", + "assert_no_warnings", + "assert_allclose", + "IgnoreException", + "clear_and_catch_warnings", + "temppath", + "tempdir", + "IS_PYPY", + "HAS_REFCOUNT", + "IS_WASM", + "suppress_warnings", + "assert_array_compare", + "assert_no_gc_cycles", + "break_cycles", + "IS_PYSTON", +] + + +verbose = 0 + +IS_WASM = platform.machine() in ["wasm32", "wasm64"] +IS_PYPY = sys.implementation.name == "pypy" +IS_PYSTON = hasattr(sys, "pyston_version_info") +HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON + + +def assert_(val, msg=""): + """ + Assert that works in release mode. + Accepts callable msg to allow deferring evaluation until failure. + + The Python built-in ``assert`` does not work when executing code in + optimized mode (the ``-O`` flag) - no byte-code is generated for it. + + For documentation on usage, refer to the Python documentation. + + """ + __tracebackhide__ = True # Hide traceback for py.test + if not val: + try: + smsg = msg() + except TypeError: + smsg = msg + raise AssertionError(smsg) + + +def gisnan(x): + return np.isnan(x) + + +def gisfinite(x): + return np.isfinite(x) + + +def gisinf(x): + return np.isinf(x) + + +def build_err_msg( + arrays, + err_msg, + header="Items are not equal:", + verbose=True, + names=("ACTUAL", "DESIRED"), + precision=8, +): + msg = ["\n" + header] + if err_msg: + if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): + msg = [msg[0] + " " + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + # r_func = partial(array_repr, precision=precision) + r_func = ndarray.__repr__ + else: + r_func = repr + + try: + r = r_func(a) + except Exception as exc: + r = f"[repr failed for <{type(a).__name__}>: {exc}]" + if r.count("\n") > 3: + r = "\n".join(r.splitlines()[:3]) + r += "..." + msg.append(f" {names[i]}: {r}") + return "\n".join(msg) + + +def assert_equal(actual, desired, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + + Examples + -------- + >>> np.testing.assert_equal([4, 5], [4, 6]) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal: + item=1 + ACTUAL: 5 + DESIRED: 6 + + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. + + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + num_nones = sum([actual is None, desired is None]) + if num_nones == 1: + raise AssertionError(f"Not equal: {actual} != {desired}") + elif num_nones == 2: + return True + # else, carry on + + if isinstance(actual, np.DType) or isinstance(desired, np.DType): + result = actual == desired + if not result: + raise AssertionError(f"Not equal: {actual} != {desired}") + else: + return True + + if isinstance(desired, str) and isinstance(actual, str): + assert actual == desired + return + + if isinstance(desired, dict): + if not isinstance(actual, dict): + raise AssertionError(repr(type(actual))) + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in desired: + if k not in actual: + raise AssertionError(repr(k)) + assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) + return + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) + return + + from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit + + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg, verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except (ValueError, TypeError): + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError(msg) # noqa: B904 + + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + + try: + # Explicitly use __eq__ for comparison, gh-2552 + if not (desired == actual): + raise AssertionError(msg) + + except (DeprecationWarning, FutureWarning) as e: + # this handles the case when the two types are not even comparable + if "elementwise == comparison" in e.args[0]: + raise AssertionError(msg) # noqa: B904 + else: + raise + + +def print_assert_equal(test_string, actual, desired): + """ + Test if two objects are equal, and print an error message if test fails. + + The test is performed with ``actual == desired``. + + Parameters + ---------- + test_string : str + The message supplied to AssertionError. + actual : object + The object to test for equality against `desired`. + desired : object + The expected result. + + Examples + -------- + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 1] + ... ) # doctest: +SKIP + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 2] + ... ) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Test XYZ of func xyz failed + ACTUAL: + [0, 1] + DESIRED: + [0, 2] + + """ + __tracebackhide__ = True # Hide traceback for py.test + import pprint + + if actual != desired: + msg = StringIO() + msg.write(test_string) + msg.write(" failed\nACTUAL: \n") + pprint.pprint(actual, msg) + msg.write("DESIRED: \n") + pprint.pprint(desired, msg) + raise AssertionError(msg.getvalue()) + + +def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies that the elements of `actual` and `desired` satisfy. + + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + decimal : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> from torch._numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 + + >>> assert_almost_equal( + ... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 9 decimals + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 6.666699636781459e-09 + Max relative difference: 2.8571569790287484e-09 + x: torch.ndarray([1.0000, 2.3333], dtype=float64) + y: torch.ndarray([1.0000, 2.3333], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import imag, iscomplexobj, ndarray, real + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + def _build_err_msg(): + header = f"Arrays are not almost equal to {decimal:d} decimals" + return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_almost_equal(actualr, desiredr, decimal=decimal) + assert_almost_equal(actuali, desiredi, decimal=decimal) + except AssertionError: + raise AssertionError(_build_err_msg()) # noqa: B904 + + if isinstance(actual, (ndarray, tuple, list)) or isinstance( + desired, (ndarray, tuple, list) + ): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(_build_err_msg()) + else: + if not desired == actual: + raise AssertionError(_build_err_msg()) + return + except (NotImplementedError, TypeError): + pass + if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): + raise AssertionError(_build_err_msg()) + + +def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to significant + digits. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + Given two numbers, check that they are approximately equal. + Approximately equal is defined as the number of significant digits + that agree. + + Parameters + ---------- + actual : scalar + The object to check. + desired : scalar + The expected object. + significant : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> np.testing.assert_approx_equal( + ... 0.12345677777777e-20, 0.1234567e-20 + ... ) # doctest: +SKIP + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345671e-20, # doctest: +SKIP + ... significant=8, + ... ) + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345672e-20, # doctest: +SKIP + ... significant=8, + ... ) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal to 8 significant digits: + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 + + the evaluated condition that raises the exception is + + >>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1) + True + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + (actual, desired) = map(float, (actual, desired)) + if desired == actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + scale = 0.5 * (np.abs(desired) + np.abs(actual)) + scale = np.power(10, np.floor(np.log10(scale))) + try: + sc_desired = desired / scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual / scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg( + [actual, desired], + err_msg, + header=f"Items are not equal to {significant:d} significant digits:", + verbose=verbose, + ) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except (TypeError, NotImplementedError): + pass + if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): + raise AssertionError(msg) + + +def assert_array_compare( + comparison, + x, + y, + err_msg="", + verbose=True, + header="", + precision=6, + equal_nan=True, + equal_inf=True, + *, + strict=False, +): + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import all, array, asarray, bool_, inf, isnan, max + + x = asarray(x) + y = asarray(y) + + def array2string(a): + return str(a) + + # original array for output formatting + ox, oy = x, y + + def func_assert_same_pos(x, y, func=isnan, hasval="nan"): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + x_id = func(x) + y_id = func(y) + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implementations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if (x_id == y_id).all().item() is not True: + msg = build_err_msg( + [x, y], + err_msg + f"\nx and y {hasval} location mismatch:", + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. + if isinstance(x_id, bool) or x_id.ndim == 0: + return bool_(x_id) + elif isinstance(y_id, bool) or y_id.ndim == 0: + return bool_(y_id) + else: + return y_id + + try: + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if not cond: + if x.shape != y.shape: + reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" + else: + reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" + msg = build_err_msg( + [x, y], + err_msg + reason, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + + flagged = bool_(False) + + if equal_nan: + flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") + + if equal_inf: + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == +inf, hasval="+inf" + ) + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == -inf, hasval="-inf" + ) + + if flagged.ndim > 0: + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return + + val = comparison(x, y) + + if isinstance(val, bool): + cond = val + reduced = array([val]) + else: + reduced = val.ravel() + cond = reduced.all() + + # The below comparison is a hack to ensure that fully masked + # results, for which val.ravel().all() returns np.ma.masked, + # do not trigger a failure (np.ma.masked != True evaluates as + # np.ma.masked, which is falsy). + if not cond: + n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements + remarks = [ + f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" + ] + + # with errstate(all='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError, RuntimeError): + error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) + max_abs_error = max(error) + remarks.append( + "Max absolute difference: " + array2string(max_abs_error.item()) + ) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + remarks.append( + "Max relative difference: " + array2string(max_rel_error.item()) + ) + + err_msg += "\n" + "\n".join(remarks) + msg = build_err_msg( + [ox, oy], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + except ValueError: + import traceback + + efmt = traceback.format_exc() + header = f"error during assertion:\n\n{efmt}\n\n{header}" + + msg = build_err_msg( + [x, y], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise ValueError(msg) # noqa: B904 + + +def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): + """ + Raises an AssertionError if two array_like objects are not equal. + + Given two array_like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + The usual caution for verifying equality with floating point numbers is + advised. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. This behaviour can be disabled with the `strict` parameter. + + Examples + -------- + The first assert does not raise an exception: + + >>> np.testing.assert_array_equal( + ... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan] + ... ) + + Use `assert_allclose` or one of the nulp (number of floating point values) + functions for these cases instead: + + >>> np.testing.assert_allclose( + ... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0 + ... ) + + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: torch.ndarray([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: torch.ndarray(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes dtype("int64"), dtype("float32") mismatch) + x: torch.ndarray([2, 2, 2]) + y: torch.ndarray([2., 2., 2.]) + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__eq__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not equal", + strict=strict, + ) + + +def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. + + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + decimal : int, optional + Desired precision, default is 6. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + the first assert does not raise an exception + + >>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan]) + + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 5.999999999994898e-05 + Max relative difference: 2.5713661239633743e-05 + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) + + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + x and y nan location mismatch: + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import any as npany, float_, issubdtype, number, result_type + + def compare(x, y): + try: + if npany(gisinf(x)) or npany(gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not (xinfid == yinfid).all(): + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except (TypeError, NotImplementedError): + pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.0) + y = asanyarray(y, dtype) + z = abs(x - y) + + if not issubdtype(z.dtype, number): + z = z.astype(float_) # handle object arrays + + return z < 1.5 * 10.0 ** (-decimal) + + assert_array_compare( + compare, + x, + y, + err_msg=err_msg, + verbose=verbose, + header=f"Arrays are not almost equal to {decimal:d} decimals", + precision=decimal, + ) + + +def assert_array_less(x, y, err_msg="", verbose=True): + """ + Raises an AssertionError if two array_like objects are not ordered by less + than. + + Given two array_like objects, check that the shape is equal and all + elements of the first object are strictly smaller than those of the + second object. An exception is raised at shape mismatch or incorrectly + ordered values. Shape mismatch does not raise if an object has zero + dimension. In contrast to the standard usage in numpy, NaNs are + compared, no assertion is raised if both objects have NaNs in the same + positions. + + + + Parameters + ---------- + x : array_like + The smaller object to check. + y : array_like + The larger object to compare. + err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_array_equal: tests objects for equality + assert_array_almost_equal: test objects for equality up to precision + + + + Examples + -------- + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 1.0 + Max relative difference: 0.5 + x: torch.ndarray([1., 1., nan], dtype=float64) + y: torch.ndarray([1., 2., nan], dtype=float64) + + >>> np.testing.assert_array_less([1.0, 4.0], 3) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 2.0 + Max relative difference: 0.6666666666666666 + x: torch.ndarray([1., 4.], dtype=float64) + y: torch.ndarray(3) + + >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + (shapes (3,), (1,) mismatch) + x: torch.ndarray([1., 2., 3.], dtype=float64) + y: torch.ndarray([4]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__lt__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not less-ordered", + equal_inf=False, + ) + + +def assert_string_equal(actual, desired): + """ + Test if two strings are equal. + + If the given strings are equal, `assert_string_equal` does nothing. + If they are not equal, an AssertionError is raised, and the diff + between the strings is shown. + + Parameters + ---------- + actual : str + The string to test for equality against the expected string. + desired : str + The expected string. + + Examples + -------- + >>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP + >>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP + Traceback (most recent call last): + File "", line 1, in + ... + AssertionError: Differences in strings: + - abc+ abcd? + + + """ + # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test + import difflib + + if not isinstance(actual, str): + raise AssertionError(repr(type(actual))) + if not isinstance(desired, str): + raise AssertionError(repr(type(desired))) + if desired == actual: + return + + diff = list( + difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) + ) + diff_list = [] + while diff: + d1 = diff.pop(0) + if d1.startswith(" "): + continue + if d1.startswith("- "): + l = [d1] + d2 = diff.pop(0) + if d2.startswith("? "): + l.append(d2) + d2 = diff.pop(0) + if not d2.startswith("+ "): + raise AssertionError(repr(d2)) + l.append(d2) + if diff: + d3 = diff.pop(0) + if d3.startswith("? "): + l.append(d3) + else: + diff.insert(0, d3) + if d2[2:] == d1[2:]: + continue + diff_list.extend(l) + continue + raise AssertionError(repr(d1)) + if not diff_list: + return + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" + if actual != desired: + raise AssertionError(msg) + + +import unittest + + +class _Dummy(unittest.TestCase): + def nop(self): + pass + + +_d = _Dummy("nop") + + +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): + """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Alternatively, can be used as a context manager like `assert_raises`. + + Notes + ----- + .. versionadded:: 1.9.0 + + """ + __tracebackhide__ = True # Hide traceback for py.test + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) + + +def decorate_methods(cls, decorator, testmatch=None): + """ + Apply a decorator to all methods in a class matching a regular expression. + + The given decorator is applied to all public methods of `cls` that are + matched by the regular expression `testmatch` + (``testmatch.search(methodname)``). Methods that are private, i.e. start + with an underscore, are ignored. + + Parameters + ---------- + cls : class + Class whose methods to decorate. + decorator : function + Decorator to apply to methods + testmatch : compiled regexp or str, optional + The regular expression. Default value is None, in which case the + nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) + is used. + If `testmatch` is a string, it is compiled to a regular expression + first. + + """ + if testmatch is None: + testmatch = re.compile(rf"(?:^|[\\b_\\.{os.sep}-])[Tt]est") + else: + testmatch = re.compile(testmatch) + cls_attr = cls.__dict__ + + # delayed import to reduce startup time + from inspect import isfunction + + methods = [_m for _m in cls_attr.values() if isfunction(_m)] + for function in methods: + try: + if hasattr(function, "compat_func_name"): + funcname = function.compat_func_name + else: + funcname = function.__name__ + except AttributeError: + # not a function + continue + if testmatch.search(funcname) and not funcname.startswith("_"): + setattr(cls, funcname, decorator(function)) + return + + +def _assert_valid_refcount(op): + """ + Check that ufuncs don't mishandle refcount of object `1`. + Used in a few regression tests. + """ + if not HAS_REFCOUNT: + return True + + import gc + + import numpy as np + + b = np.arange(100 * 100).reshape(100, 100) + c = b + i = 1 + + gc.disable() + try: + rc = sys.getrefcount(i) + for _ in range(15): + d = op(b, c) + assert_(sys.getrefcount(i) >= rc) + finally: + gc.enable() + del d # for pyflakes + + +def assert_allclose( + actual, + desired, + rtol=1e-7, + atol=0, + equal_nan=True, + err_msg="", + verbose=True, + check_dtype=False, +): + """ + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Given two array_like objects, check that their shapes and all elements + are equal (but see the Notes for the special handling of a scalar). An + exception is raised if the shapes mismatch or any values conflict. In + contrast to the standard usage in numpy, NaNs are compared like numbers, + no assertion is raised if both objects have NaNs in the same positions. + + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note + that ``allclose`` has different default values). It compares the difference + between `actual` and `desired` to ``atol + rtol * abs(desired)``. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_array_almost_equal_nulp, assert_array_max_ulp + + Notes + ----- + When one of `actual` and `desired` is a scalar and the other is + array_like, the function checks that each element of the array_like + object is equal to the scalar. + + Examples + -------- + >>> x = [1e-5, 1e-3, 1e-1] + >>> y = np.arccos(np.cos(x)) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + def compare(x, y): + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + actual, desired = asanyarray(actual), asanyarray(desired) + header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" + + if check_dtype: + assert actual.dtype == desired.dtype + + assert_array_compare( + compare, + actual, + desired, + err_msg=str(err_msg), + verbose=verbose, + header=header, + equal_nan=equal_nan, + ) + + +def assert_array_almost_equal_nulp(x, y, nulp=1): + """ + Compare two arrays relatively to their spacing. + + This is a relatively robust method to compare two arrays whose amplitude + is variable. + + Parameters + ---------- + x, y : array_like + Input arrays. + nulp : int, optional + The maximum number of unit in the last place for tolerance (see Notes). + Default is 1. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the spacing between `x` and `y` for one or more elements is larger + than `nulp`. + + See Also + -------- + assert_array_max_ulp : Check that all items of arrays differ in at most + N Units in the Last Place. + spacing : Return the distance between x and the nearest adjacent number. + + Notes + ----- + An assertion is raised if the following condition is not met:: + + abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) + + Examples + -------- + >>> x = np.array([1.0, 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: X and Y are not equal to 1 ULP (max is 2) + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ax = np.abs(x) + ay = np.abs(y) + ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) + if not np.all(np.abs(x - y) <= ref): + if np.iscomplexobj(x) or np.iscomplexobj(y): + msg = f"X and Y are not equal to {nulp:d} ULP" + else: + max_nulp = np.max(nulp_diff(x, y)) + msg = f"X and Y are not equal to {nulp:d} ULP (max is {max_nulp:g})" + raise AssertionError(msg) + + +def assert_array_max_ulp(a, b, maxulp=1, dtype=None): + """ + Check that all items of arrays differ in at most N Units in the Last Place. + + Parameters + ---------- + a, b : array_like + Input arrays to be compared. + maxulp : int, optional + The maximum number of units in the last place that elements of `a` and + `b` can differ. Default is 1. + dtype : dtype, optional + Data-type to convert `a` and `b` to if given. Default is None. + + Returns + ------- + ret : ndarray + Array containing number of representable floating point numbers between + items in `a` and `b`. + + Raises + ------ + AssertionError + If one or more elements differ by more than `maxulp`. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + See Also + -------- + assert_array_almost_equal_nulp : Compare two arrays relatively to their + spacing. + + Examples + -------- + >>> a = np.linspace(0.0, 1.0, 100) + >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ret = nulp_diff(a, b, dtype) + if not np.all(ret <= maxulp): + raise AssertionError( + f"Arrays are not almost equal up to {maxulp:g} " + f"ULP (max difference is {np.max(ret):g} ULP)" + ) + return ret + + +def nulp_diff(x, y, dtype=None): + """For each item in x and y, return the number of representable floating + points between them. + + Parameters + ---------- + x : array_like + first input array + y : array_like + second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. + + Returns + ------- + nulp : array_like + number of representable floating point numbers between each item in x + and y. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + Examples + -------- + # By definition, epsilon is the smallest number such as 1 + eps != 1, so + # there should be exactly one ULP between 1 and 1 + eps + >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP + 1.0 + """ + import numpy as np + + if dtype: + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) + else: + x = np.asarray(x) + y = np.asarray(y) + + t = np.common_type(x, y) + if np.iscomplexobj(x) or np.iscomplexobj(y): + raise NotImplementedError("_nulp not implemented for complex array") + + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan + + if not x.shape == y.shape: + raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") + + def _diff(rx, ry, vdt): + diff = np.asarray(rx - ry, dtype=vdt) + return np.abs(diff) + + rx = integer_repr(x) + ry = integer_repr(y) + return _diff(rx, ry, t) + + +def _integer_repr(x, vdt, comp): + # Reinterpret binary representation of the float as sign-magnitude: + # take into account two-complement representation + # See also + # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + rx = x.view(vdt) + if rx.size != 1: + rx[rx < 0] = comp - rx[rx < 0] + else: + if rx < 0: + rx = comp - rx + + return rx + + +def integer_repr(x): + """Return the signed-magnitude interpretation of the binary representation + of x.""" + import numpy as np + + if x.dtype == np.float16: + return _integer_repr(x, np.int16, np.int16(-(2**15))) + elif x.dtype == np.float32: + return _integer_repr(x, np.int32, np.int32(-(2**31))) + elif x.dtype == np.float64: + return _integer_repr(x, np.int64, np.int64(-(2**63))) + else: + raise ValueError(f"Unsupported dtype {x.dtype}") + + +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): + """ + Fail unless the given callable throws the specified warning. + + A warning of class warning_class should be thrown by the callable when + invoked with arguments args and keyword arguments kwargs. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + warning_class : class + The class defining the warning that `func` is expected to throw. + func : callable, optional + Callable to test + *args : Arguments + Arguments for `func`. + **kwargs : Kwargs + Keyword arguments for `func`. + + Returns + ------- + The value returned by `func`. + + Examples + -------- + >>> import warnings + >>> def deprecated_func(num): + ... warnings.warn("Please upgrade", DeprecationWarning) + ... return num * num + >>> with np.testing.assert_warns(DeprecationWarning): + ... assert deprecated_func(4) == 16 + >>> # or passing a func + >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) + >>> assert ret == 16 + """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter("always") + yield + if len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError(f"Got warnings{name_str}: {l}") + + +def assert_no_warnings(*args, **kwargs): + """ + Fail if the given callable produces any warnings. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + The value returned by `func`. + + """ + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) + + +def _gen_alignment_data(dtype=float32, type="binary", max_size=24): + """ + generator producing data with different alignment and offsets + to test simd vectorization + + Parameters + ---------- + dtype : dtype + data type to produce + type : string + 'unary': create data for unary operations, creates one input + and output array + 'binary': create data for unary operations, creates two input + and output array + max_size : integer + maximum size of data to produce + + Returns + ------- + if type is 'unary' yields one output, one input array and a message + containing information on the data + if type is 'binary' yields one output array, two input array and a message + containing information on the data + + """ + ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" + bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" + for o in range(3): + for s in range(o + 2, max(o + 3, max_size)): + if type == "unary": + + def inp(): + return arange(s, dtype=dtype)[o:] + + out = empty((s,), dtype=dtype)[o:] + yield out, inp(), ufmt % (o, o, s, dtype, "out of place") + d = inp() + yield d, d, ufmt % (o, o, s, dtype, "in place") + yield ( + out[1:], + inp()[:-1], + ufmt + % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp()[1:], + ufmt + % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ), + ) + yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") + yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") + if type == "binary": + + def inp1(): + return arange(s, dtype=dtype)[o:] + + inp2 = inp1 + out = empty((s,), dtype=dtype)[o:] + yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") + d = inp1() + yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") + d = inp2() + yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") + yield ( + out[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + inp1()[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ), + ) + yield ( + inp1()[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ), + ) + yield ( + inp1()[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ), + ) + + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP + ... modules=[np.core.fromnumeric] + ... ): + ... warnings.simplefilter("always") + ... warnings.filterwarnings("ignore", module="np.core.fromnumeric") + ... # do something that raises a warning but ignore those in + ... # np.core.fromnumeric + """ + + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super().__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super().__enter__() + + def __exit__(self, *exc_info): + super().__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings: + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + https://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + + + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass + """ + + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings("always", category=category, message=message) + else: + module_regex = module.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=category, message=message, module=module_regex + ) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + else: + self._suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + return self._filter( + category=category, message=message, module=module, record=True + ) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings("always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=cat, message=mess, module=module_regex + ) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarning( + self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs + ): + for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ + ::-1 + ]: + if issubclass(category, cat) and pattern.match(message.args[0]) is not None: + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + +@contextlib.contextmanager +def _assert_no_gc_cycles_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + + # not meaningful to test if there is no refcounting + if not HAS_REFCOUNT: + yield + return + + assert_(gc.isenabled()) + gc.disable() + gc_debug = gc.get_debug() + try: + for _ in range(100): + if gc.collect() == 0: + break + else: + raise RuntimeError( + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?" + ) + + gc.set_debug(gc.DEBUG_SAVEALL) + yield + # gc.collect returns the number of unreachable objects in cycles that + # were found -- we are checking that no cycles were created in the context + n_objects_in_cycles = gc.collect() + objects_in_cycles = gc.garbage[:] + finally: + del gc.garbage[:] + gc.set_debug(gc_debug) + gc.enable() + + if n_objects_in_cycles: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError( + "Reference cycles were found{}: {} objects were collected, " + "of which {} are shown below:{}".format( + name_str, + n_objects_in_cycles, + len(objects_in_cycles), + "".join( + "\n {} object with id={}:\n {}".format( + type(o).__name__, + id(o), + pprint.pformat(o).replace("\n", "\n "), + ) + for o in objects_in_cycles + ), + ) + ) + + +def assert_no_gc_cycles(*args, **kwargs): + """ + Fail if the given callable produces any reference cycles. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_gc_cycles(): + do_something() + + .. versionadded:: 1.15.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + Nothing. The result is deliberately discarded to ensure that all cycles + are found. + + """ + if not args: + return _assert_no_gc_cycles_context() + + func = args[0] + args = args[1:] + with _assert_no_gc_cycles_context(name=func.__name__): + func(*args, **kwargs) + + +def break_cycles(): + """ + Break reference cycles by calling gc.collect + Objects can call other objects' methods (for instance, another object's + __del__) inside their own __del__. On PyPy, the interpreter only runs + between calls to gc.collect, so multiple calls are needed to completely + release all cycles. + """ + + gc.collect() + if IS_PYPY: + # a few more, just to make sure all the finalizers are called + gc.collect() + gc.collect() + gc.collect() + gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = "NPY_AVAILABLE_MEM" + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError( # noqa: B904 + f"Invalid environment variable {env_var}: {exc}" + ) + + msg = ( + f"{free_bytes / 1e9} GB memory required, but environment variable " + f"NPY_AVAILABLE_MEM={env_value} set" + ) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ( + "Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test." + ) + mem_free = -1 + else: + msg = f"{free_bytes / 1e9} GB memory required, but {mem_free / 1e9} GB available" + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) to float""" + suffixes = { + "": 1, + "b": 1, + "k": 1000, + "m": 1000**2, + "g": 1000**3, + "t": 1000**4, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "tb": 1000**4, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + } + + size_re = re.compile( + r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), + re.IGNORECASE, + ) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError(f"value {size_str!r} not a valid size") + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith("linux"): + info = {} + with open("/proc/meminfo") as f: + for line in f: + p = line.split() + info[p[0].strip(":").lower()] = int(p[1]) * 1024 + + if "memavailable" in info: + # Linux >= 3.14 + return info["memavailable"] + else: + return info["memfree"] + info["cached"] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, "gettrace"): + return func + else: + + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + + return wrapper + + +def _get_glibc_version(): + try: + ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] + except Exception: + ver = "0.0" + + return ver + + +_glibcver = _get_glibc_version() + + +def _glibc_older_than(x): + return _glibcver != "0.0" and _glibcver < x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..475fe5eee8db16ef55a118f79ee14ff1f1cf3a4f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b85e808fd34bcfd50a3cc40386dfa52d99a642 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dca71fcf09b019f3e197576eb415ba4fd54fa28a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from .linear import Linear + + +__all__ = ["Linear"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3a86d4e00ff91be7bb21adbbb7b2a288cf3b88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42323a2cf321d18d98539d2a5b8176bf002d41a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2238eedf6f902174421a94702a4188fa463098 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,40 @@ +from typing import Optional, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfig + + +__all__ = ["Linear"] + + +class Linear(torch.ao.nn.qat.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for dynamic quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: Optional["QConfig"] = None, + device: int | str | torch.device | None = None, + dtype: str | None = None, + ) -> None: + super().__init__(in_features, out_features, bias, qconfig, device, dtype) + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): # type: ignore[arg-type] + raise ValueError( + "Dynamic QAT requires a memoryless observer." + + "This means a MovingAverage observer with averaging constant equal to 1" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e28e0968a60d7612ebbd26d5f607b4407c2d380 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py @@ -0,0 +1,13 @@ +from .conv import Conv1d, Conv2d, Conv3d +from .embedding_ops import Embedding, EmbeddingBag +from .linear import Linear + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b5e7c6a70f23c9ad6095185c67597cffa4782fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f3a5fc6a42d27a1cc69e10c3158e29bcf034047 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae0f612915c77f42972a9daa02f10a86fd8bf72 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..765fa4be205a292f114c385936a20eab6d38e892 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9d228d56fce129860f0ebad805b042771b941804 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py @@ -0,0 +1,312 @@ +# mypy: allow-untyped-defs +from typing import ClassVar, Literal + +import torch +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = ["Conv1d", "Conv2d", "Conv3d"] + + +class _ConvNd(nn.modules.conv._ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: str | tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"], + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + nn.modules.conv._ConvNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: + `mod`: a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) is cls._FLOAT_MODULE, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if issubclass(type(mod), _FusedModule): + mod = mod[0] + qconfig = mod.qconfig + qat_conv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + stride=mod.stride, + padding=mod.padding, + dilation=mod.dilation, + groups=mod.groups, + bias=mod.bias is not None, + padding_mode=mod.padding_mode, + qconfig=qconfig, + ) + qat_conv.weight = mod.weight + qat_conv.bias = mod.bias + return qat_conv + + def to_float(self): + """This works for both single qat conv, and the qat conv - relu modules + to convert the qat module to a floating point module + """ + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + # conv relu + if issubclass(cls, _FusedModule): + modules = [conv] + assert hasattr(cls, "_FLOAT_RELU_MODULE") + relu = cls._FLOAT_RELU_MODULE() + modules.append(relu) + # pyrefly: ignore [missing-attribute] + fused = cls._FLOAT_MODULE(*modules) + fused.train(self.training) + return fused + else: + return conv + + +class Conv1d(_ConvNd, nn.Conv1d): + r""" + A Conv1d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as :class:`~torch.nn.Conv1d` + + Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: str | _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_single(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd, nn.Conv2d): + r""" + A Conv2d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv2d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d + for documentation. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: str | _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_pair(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd, nn.Conv3d): + r""" + A Conv3d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv3d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d + for documentation. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: str | _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_triple(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1f69e70abcf1d43c4a96ca15dae355c31f66a627 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding + for documentation. + + Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Embedding + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + qconfig=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + # pyrefly: ignore [bad-argument-type] + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input) -> Tensor: + return F.embedding( + input, + self.weight_fake_quant(self.weight), + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) is cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.Embedding( + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + None, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag + + +class EmbeddingBag(nn.EmbeddingBag): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag + for documentation. + + Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.EmbeddingBag + + def __init__( + self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + mode="mean", + sparse=False, + _weight=None, + include_last_offset=False, + padding_idx=None, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: + return F.embedding_bag( + input, + self.weight_fake_quant(self.weight), + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) is cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.EmbeddingBag( + self.num_embeddings, + self.embedding_dim, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + None, + self.include_last_offset, + self.padding_idx, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5edf16ed3ea53d0323eda248b95703d5245b1786 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.intrinsic import LinearReLU +from torch.nn.utils.parametrize import ( + is_parametrized, + transfer_parametrizations_and_params, + type_before_parametrizations, +) + + +__all__ = ["Linear"] + + +class Linear(nn.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Linear + + def __init__( + self, + in_features, + out_features, + bias=True, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(in_features, out_features, bias, **factory_kwargs) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return F.linear(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if type_before_parametrizations(mod) == LinearReLU: + mod = mod[0] + + qconfig = mod.qconfig + qat_linear = cls( + mod.in_features, + mod.out_features, + bias=mod.bias is not None, + qconfig=qconfig, + ) + + if is_parametrized(mod, "weight"): + transfer_parametrizations_and_params(mod, qat_linear, "weight") + else: + qat_linear.weight = mod.weight + + if is_parametrized(mod, "bias"): + transfer_parametrizations_and_params(mod, qat_linear, "bias") + else: + qat_linear.bias = mod.bias + + return qat_linear + + def to_float(self): + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + linear.train(self.training) + return linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7e908018e2774ade165669dbbc0c77b5d841b68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69210e6d126cd8262179a2a3e2202021358abbaa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b83664db976aa83097c76374b1033d9bcd311c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60752e127fec2cc43dacc1c5d0f174040bdc7cfb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ed68428d0eaf9c8219b9247330bcbf1a07b9410 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6141a3cf3db6c364350ba37d7f1fc465141068a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61a42257fb14ccee741ab8a6933f6469a7eb3bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88f04b5fc307aaac5a7d67bd00ff5f3125b1cde7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89417724e474060974259c7bb938a95357cabec4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00fb64cbc21a406ade2fa9713586905a45aa99c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80d0af289e29e2d9e4764fd04bc590efabe8d998 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efda7b7b6a88b55dbc27752a96100572dd7c49ee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7581f8176c0f597a0fb43970789fd7a52faa45cf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e2ecb7b4bd41ade2df75537bc56446d2331f592 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80caf33f8d430206db3e8d0f09092cb3e6ccf43e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdad3f2d9bc49094c0da3264012cc206c28ab86 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +import collections +import enum +from typing import Any + +import torch +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Graph, Node + +from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map +from .ns_types import NSNodeTargetType, NSSubgraph +from .pattern_utils import ( + end_node_matches_reversed_fusion, + get_reversed_fusions, + get_type_a_related_to_b, +) + + +toq = torch.ops.quantized + + +def _get_output_nodes(g: Graph) -> list[Node]: + return [n for n in g.nodes if n.op == "output"] + + +class _NSGraphMatchableSubgraphsIterator: + """ + Iterates through the graph of gm, starting with the output nodes + and continuing backwards. + 1. Returns matchable subgraphs, in order. A subgraph is defined by + (start_node, end_node). + 2. Skips over non-matchable subgraphs + """ + + def __init__( + self, + gm: GraphModule, + non_matchable_functions: set[NSNodeTargetType], + non_matchable_modules: set[NSNodeTargetType], + non_matchable_methods: set[NSNodeTargetType], + ): + self.gm: GraphModule = gm + self.non_matchable_functions: set[NSNodeTargetType] = non_matchable_functions + self.non_matchable_modules: set[NSNodeTargetType] = non_matchable_modules + self.non_matchable_methods: set[NSNodeTargetType] = non_matchable_methods + self.seen_nodes: set[Node] = set() + self.stack: list[Node] = [] + for start_node in _get_output_nodes(self.gm.graph): + self.stack.append(start_node) + + def __iter__(self): + return self + + def __next__(self) -> NSSubgraph: + """ + Returns the next matchable subgraph. + """ + while len(self.stack) > 0: + cur_end_node = self.stack.pop() + if cur_end_node in self.seen_nodes: + continue + + # for subgraphs which are single nodes, start_node == end_node + # for subgraphs with more than one node, start node != end_node + cur_start_node = cur_end_node + # Subgraphs like linear-relu have the base node as the start node. + # Subgraphs like dequantize-linear-relu-to(torch.float16) have the + # base node as the second node. + # The cur_base_op_node var will move to the actual node during + # the fusion matching later in this code block. + cur_base_op_node = cur_end_node + + # Check for potential fusions. For now, we are greedy + # and always skip all non-base nodes of a fusion. For example, + # if we match linear-relu backwards, we will always skip the + # relu node and attempt to match the linear node. This can + # be made configurable later if needed. + for _reverse_fusion_ops, base_op_idx in get_reversed_fusions(): + is_match = end_node_matches_reversed_fusion( + cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes + ) + if is_match: + # navigate to the base node + for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1): + # pyrefly: ignore [bad-argument-type] + self.seen_nodes.add(cur_start_node) + # for now, assume that there are no other nodes + # which need to be added to the stack + cur_start_node = cur_start_node.args[0] # type: ignore[assignment] + # if the base op index matches the current node, set it + rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx + if rev_fusion_idx == rev_base_op_idx: + cur_base_op_node = cur_start_node + break + + # pyrefly: ignore [bad-argument-type] + self.seen_nodes.add(cur_start_node) + # add args of previous nodes to stack + # pyrefly: ignore [missing-attribute] + for arg in cur_start_node.all_input_nodes: + self._recursively_add_node_arg_to_stack(arg) + + # skip unmatchable nodes + # note: this check is done on the start_node, i.e. + # if we are matching linear-relu in reverse, this would do the matchable + # check on the linear + # pyrefly: ignore [bad-argument-type] + if not self._is_matchable(cur_base_op_node): + continue + + # If an observer or a fake_quant was not matched as a part of + # a pattern of multiple nodes, ignore it. One case where this is + # relevant is an observer on a graph input, which was added because + # it is necessary for the next node. + if cur_end_node.op == "call_module" and cur_start_node is cur_end_node: + maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type] + if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)): + continue + + return NSSubgraph( + # pyrefly: ignore [bad-argument-type] + start_node=cur_start_node, + end_node=cur_end_node, + # pyrefly: ignore [bad-argument-type] + base_op_node=cur_base_op_node, + ) + + raise StopIteration + + def _recursively_add_node_arg_to_stack(self, arg: Any) -> None: + """ + Adds all of the nodes in this arg to the stack, properly navigating + through list, dicts and tuples. + """ + if isinstance(arg, Node): + self.stack.append(arg) + elif ( + isinstance(arg, torch.fx.immutable_collections.immutable_list) + or type(arg) is tuple + ): + for inner_arg in arg: + self._recursively_add_node_arg_to_stack(inner_arg) + elif isinstance(arg, torch.fx.immutable_collections.immutable_dict): + for value in arg.values(): + self._recursively_add_node_arg_to_stack(value) + + def _is_matchable(self, node: Node) -> bool: + if node.op == "call_function": + return node.target not in self.non_matchable_functions + elif node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + target_mod = getattr_from_fqn(self.gm, node.target) + return not any( + isinstance(target_mod, t) # type: ignore[arg-type] + for t in self.non_matchable_modules + ) + elif node.op == "call_method": + return node.target not in self.non_matchable_methods + else: + return False + + +class GraphMatchingException(Exception): + """ + Exception raised when two graphs cannot be matched. + """ + + +class SubgraphTypeRelationship(enum.Enum): + # same type, known + # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d + EQUAL = enum.auto() + # same type, but the type is not known to Numerical Suite + # (user defined type, etc). + EQUAL_BUT_UKNOWN = enum.auto() + # known, same subgraph_relationship set, but not the same type + # example: F.linear and toq.linear + RELATED_BUT_NOT_EQUAL = enum.auto() + # not related + NOT_RELATED = enum.auto() + + +def _get_subgraph_relationship_type( + subgraph_a: NSSubgraph, + subgraph_b: NSSubgraph, + gm_a: GraphModule, + gm_b: GraphModule, + type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]], +) -> SubgraphTypeRelationship: + node_a = subgraph_a.base_op_node + node_b = subgraph_b.base_op_node + + # TODO(next): make this code handle matching by what is before the base op + if node_a.op != node_b.op: + if not ( + node_a.op in ("call_function", "call_method") + and node_b.op in ("call_function", "call_method") + ): + return SubgraphTypeRelationship.NOT_RELATED + + if node_a.op in ("call_function", "call_method"): + key = (node_a.target, node_b.target) + + if key not in type_a_related_to_b: + if node_a.target == node_b.target: + return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN + else: + return SubgraphTypeRelationship.NOT_RELATED + # after this point, we are dealing with known types + + if node_a.target == node_b.target: + node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node + node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node + if node_a_has_prev and (not node_b_has_prev): + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + elif (not node_a_has_prev) and node_b_has_prev: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + elif (not node_a_has_prev) and (not node_b_has_prev): + return SubgraphTypeRelationship.EQUAL + else: + # TODO(future PR): check for matches start_op_node and base_op_node + return SubgraphTypeRelationship.EQUAL + + if key in type_a_related_to_b: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + else: + return SubgraphTypeRelationship.NOT_RELATED + elif node_a.op == "call_module": + if ( + subgraph_a.base_op_node != subgraph_a.start_node + or subgraph_b.base_op_node != subgraph_b.start_node + ): + raise AssertionError( + "Matching call_module patterns where base_op_node != start_node is not supported yet" + ) + # for call_module, we need to look up the modules to do the type check + if not isinstance(node_a.target, str): + raise AssertionError(f"Expected str, got {type(node_a.target)}") + mod_a = getattr_from_fqn(gm_a, node_a.target) + if not isinstance(node_b.target, str): + raise AssertionError(f"Expected str, got {type(node_b.target)}") + mod_b = getattr_from_fqn(gm_b, node_b.target) + + key = (type(mod_a), type(mod_b)) + + if key not in type_a_related_to_b: + if type(mod_a) is type(mod_b): + return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN + else: + return SubgraphTypeRelationship.NOT_RELATED + elif type(mod_a) is type(mod_b): + return SubgraphTypeRelationship.EQUAL + else: + return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL + + return SubgraphTypeRelationship.NOT_RELATED + + +def _get_name_for_subgraph( + subgraph_a: NSSubgraph, + gm_a: GraphModule, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + existing_names: set[str], +) -> str: + """ + Returns a unique name for a subgraph. This name is based on two things: + 1. the name of the set containing the underlying type of the base op in the + subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op) + 2. the number of previous subgraphs with related underlying type of the base op + + For example, in the graph + + linear0 -> relu0 -> linear1 -> relu1 + + The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate + from the output node backwards, the name given to (linear1, relu1) will be + `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0) + will be `base_op_torch.nn.functional.linear_1`. + + Why are we not just using the node name? Answer: because of two requirements: + A. fusions must be supported + B. some Numeric Suite APIs can be called without having all of the models in memory + + For example, let's say we need to match nodes of + + (1) ... -> linear0 -> relu0 -> ... + + And + + (2) ... -> linear_relu0 -> ... + + Without being able to inspect them together. With the current naming scheme, if + we iterate through both of these graphs in the same order, and assuming the rest + of the graphs match, both of these subgraphs will get the same name without + (1) and (2) knowing anything about each other. + """ + target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a) + target_base_type = None + for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items(): + if target_type in sets_of_related_ops: + target_base_type = base_name + target_base_name = "base_op_" + str(target_base_type) + counter = 0 + proposed_name = target_base_name + "_" + str(counter) + while proposed_name in existing_names: + counter += 1 + proposed_name = target_base_name + "_" + str(counter) + existing_names.add(proposed_name) + return proposed_name + + +def _get_node_target_type(node: Node, gm: GraphModule) -> NSNodeTargetType | None: + if node.op in ("call_function", "call_method"): + return node.target + elif node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + mod = getattr_from_fqn(gm, node.target) + return type(mod) + return None + + +def get_matching_subgraph_pairs( + gm_a: GraphModule, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> dict[str, tuple[NSSubgraph, NSSubgraph]]: + """ + Matches matchable subgraphs of graph_a to graph_b. + + For a node, "matchable" is defined as a node which is not an observer, + fake_quants, quant or dequant. + + A subgraph can contain one or more nodes. A subgraph is matchable if + at least one node inside of it is matchable. Currently, all nodes in + a subgraph must be matchable (because we assume no observers will be + inserted in the middle of a fusion). + + A subgraph is defined by (start_node, end_node). We assume that only + start_node and end_node are linked with the surrounding graph, all other + nodes in a subgraph are self-contained. + + A pair of nodes is "related" if both nodes represent the same mathematical + operation across different quantization flavors. For example, + `F.linear` and `torch.ops.quantized.linear` are related, and + `F.linear` and `torch.nn.Conv` are not related. + + For each matchable pair of nodes node_a and node_b, they will match + if node_a and node_b are related. + + For graphs A and B, they will match iff: + 1. the number of matchable subgraphs in A and B is equivalent + 2. when iterating through the matchable subgraphs of A and B in the same order, each + corresponding pair of base nodes is related. + + This enables us to find the corresponding subgraphs between + graphs of related models. For example, if we had two graphs such as: + + graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1 + w -/ + b -/ + + graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1 + packed_params_0 -/ + + This function will return the following result: + { + 'conv_0': ( # the name of the node in graph_b + (conv_0, conv_0), # (start_node_a, end_node_a) + (qconv_0, qconv_0), # (start_node_b, end_node_b) + ), + } + + Or, if we have a fusion pattern, + + graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1 + w -/ + b -/ + + graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1 + packed_params_0 -/ + + This function will return the following result: + { + 'linear_relu_0': ( # the name of the node in graph_b + (linear_0, relu_0), # (start_node_a, end_node_a) + (linear_relu_0, linear_relu_0), # (start_node_b, end_node_b) + ), + } + """ + if unmatchable_types_map is None: + unmatchable_types_map = get_unmatchable_types_map() + non_matchable_functions = unmatchable_types_map["funs_unmatchable"] + non_matchable_modules = unmatchable_types_map["mods_unmatchable"] + non_matchable_methods = unmatchable_types_map["meths_unmatchable"] + + graph_a_iterator = _NSGraphMatchableSubgraphsIterator( + gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods + ) + graph_b_iterator = _NSGraphMatchableSubgraphsIterator( + gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods + ) + results = collections.OrderedDict() + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() + type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops) + + existing_names_a: set[str] = set() + existing_names_b: set[str] = set() + + while True: + # fetch the next subgraphs from a and b + cur_subgraph_a, cur_subgraph_b = None, None + try: + cur_subgraph_a = next(graph_a_iterator) + except StopIteration: + pass + try: + cur_subgraph_b = next(graph_b_iterator) + except StopIteration: + pass + + # look up types of a and b for useful error messages + type_start_a, type_start_b = None, None + if cur_subgraph_a is not None: + type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a) + if cur_subgraph_b is not None: + type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b) + + # check for results and determine what to do next + if cur_subgraph_a is not None and cur_subgraph_b is not None: + # both nodes were fetched, check for subgraph_relationship + # note: subgraph_relationship is checked on the start node, i.e. + # if a linear-relu pattern is checked, we would check for subgraph_relationship + # of the linear + subgraph_relationship = _get_subgraph_relationship_type( + cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b + ) + if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED: + msg = f""" +The subgraphs +({cur_subgraph_a}, {type_start_a}) and +({cur_subgraph_b}, {type_start_b}) +are not related. Please ensure that the two models you pass in have the same number +of subgraphs, and each pair of subgraphs is related to each other.""" + raise GraphMatchingException(msg) + elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN: + # skip matching but unknown types + continue + key_name_a = _get_name_for_subgraph( + cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a + ) + key_name_b = _get_name_for_subgraph( + cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b + ) + if key_name_a != key_name_b: + raise AssertionError( + f"Subgraph names {key_name_a} and {key_name_b} do not match" + ) + results[key_name_a] = (cur_subgraph_a, cur_subgraph_b) + continue + elif cur_subgraph_a is None and cur_subgraph_b is None: + # we reached the end of both graphs + break + else: + # only one node was fetched, no match possible, throw error + msg = f""" +Attempting to match +({cur_subgraph_a}, {type_start_a}) and +({cur_subgraph_b}, {type_start_b}), +one of which is empty. Please ensure that the two models you pass in have the same number +of subgraphs.""" + raise GraphMatchingException(msg) + + # The subgraph pairs are originally created by traversing the two graphs + # from the outputs to the inputs. Reverse the results to return the + # subgraphs in their order of execution. + results = collections.OrderedDict(reversed(results.items())) + + # pyrefly: ignore [bad-return] + return results diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py new file mode 100644 index 0000000000000000000000000000000000000000..338db28ce41d96ec5d3de38591f5937543d65394 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py @@ -0,0 +1,1155 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.observer import _is_activation_post_process +from torch.fx import GraphModule, map_arg +from torch.fx.graph import Graph, Node + +from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph +from .utils import ( + get_arg_indices_of_inputs_to_log, + get_node_first_input_and_output_type, + get_node_input_qparams, + get_normalized_nth_input, + get_number_of_non_param_args, + get_target_type_str, + getattr_from_fqn, + NodeInputOrOutputType, + op_type_supports_shadowing, + return_first_non_observer_node, +) + + +def _maybe_get_fqn(node: Node, gm: GraphModule) -> str | None: + fqn = None + if hasattr(gm, "_node_name_to_scope"): + # fqn on observers is not present, because they do not + # exist when the fqns are created during tracing. If this is + # an observer, get the fqn of the node being observed. + node_to_use_for_fqn = node + if node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + module = getattr_from_fqn(gm, node.target) + if _is_activation_post_process(module): + node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0) + fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index] + return fqn # type: ignore[return-value] + + +def _insert_logger_after_node( + node: Node, + gm: GraphModule, + logger_cls: Callable, + logger_node_name_suffix: str, + ref_node_name: str, + model_name: str, + ref_name: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: str | None, +) -> Node: + """ + Given a starting graph of + + prev_node -> node -> next_node + + This function creates a new logger_cls obj and adds it + after node, resulting in + + prev_node -> node -> logger_obj -> next_node + """ + # create new name + logger_node_name = get_new_attr_name_with_prefix( + node.name + logger_node_name_suffix + )(gm) + target_type = get_target_type_str(node, gm) + # create the logger object + logger_obj = logger_cls( + ref_node_name, + node.name, + model_name, + ref_name, + target_type, + ref_node_target_type, + results_type, + index_within_arg, + index_of_arg, + fqn, + ) + # attach the logger object to the parent module + setattr(gm, logger_node_name, logger_obj) + logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {}) + return logger_node + + +def add_loggers_to_model( + gm: GraphModule, + node_to_instrument_inputs_to_ref_node_name: dict[Node, tuple[str, str]], + node_to_instrument_outputs_to_ref_node_name: dict[Node, tuple[str, str]], + logger_cls: Callable, + model_name: str, +) -> GraphModule: + """ + Takes the graph of gm, adds loggers to the output + of each node in nodes_to_instrument. Returns a GraphModule with the new + graph. + """ + + new_graph = Graph() + env: dict[str, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in gm.graph.nodes: + if node.op == "output": + new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg)) + continue + + if (node in node_to_instrument_inputs_to_ref_node_name) or ( + node in node_to_instrument_outputs_to_ref_node_name + ): + fqn = _maybe_get_fqn(node, gm) + + if node in node_to_instrument_inputs_to_ref_node_name: + ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[ + node + ] + # Ops such add and mul are special because either + # one or two of the first two arguments can be tensors, + # and if one argument is a tensor it can be first or + # second (x + 1 versus 1 + x). + arg_indices_to_log = get_arg_indices_of_inputs_to_log(node) + for node_arg_idx in arg_indices_to_log: + node_arg = get_normalized_nth_input(node, gm, node_arg_idx) + if type(node_arg) is Node: + # create a single input logger + prev_node = env[node_arg.name] + env[node_arg.name] = _insert_logger_after_node( + prev_node, + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=node_arg_idx, + fqn=fqn, + ) + elif ( + type(node_arg) is torch.fx.immutable_collections.immutable_list + ): + # create N input loggers, one for each node + for arg_idx, arg in enumerate(node_arg): # type: ignore[var-annotated, arg-type] + prev_node = env[arg.name] + env[prev_node.name] = _insert_logger_after_node( + prev_node, + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=arg_idx, + index_of_arg=node_arg_idx, + fqn=fqn, + ) + + # ensure env is populated with base node + # Note: runs for both inputs and outputs + env[node.name] = new_graph.node_copy(node, load_arg) + + if node in node_to_instrument_outputs_to_ref_node_name: + ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[ + node + ] + # add the logger after the base node + env[node.name] = _insert_logger_after_node( + env[node.name], + gm, + logger_cls, + "_ns_logger_", + node.name, + model_name, + ref_name, + ref_node_type, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn, + ) + + else: + env[node.name] = new_graph.node_copy(node, load_arg) + + new_gm = GraphModule(gm, new_graph) + return new_gm + + +def _insert_quantize_per_tensor_node( + prev_node_c: Node, + node_a: Node, + gm_b: GraphModule, + graph_c: Graph, + scale: torch.Tensor | float, + zero_point: torch.Tensor | int, + dtype_cast_name: str, +) -> Node: + # copy scale + scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b) + setattr(gm_b, scale_node_name, scale) + scale_node = graph_c.create_node( + "get_attr", scale_node_name, (), {}, scale_node_name + ) + # copy zero_point + zero_point_node_name = get_new_attr_name_with_prefix( + node_a.name + "_input_zero_point_" + )(gm_b) + setattr(gm_b, zero_point_node_name, zero_point) + zero_point_node = graph_c.create_node( + "get_attr", zero_point_node_name, (), {}, zero_point_node_name + ) + # create the quantize_per_tensor call + return graph_c.create_node( + "call_function", + torch.quantize_per_tensor, + (prev_node_c, scale_node, zero_point_node, torch.quint8), + {}, + dtype_cast_name, + ) + + +def _insert_dtype_cast_after_node( + node_a: Node, + node_c: Node, + prev_node_c: Node | list[Node], + gm_a: GraphModule, + gm_b: GraphModule, + graph_c: Graph, + node_name_prefix: str, + logger_cls: Callable, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> Node | list[Node]: + """ + Given a starting graph C (derived from graph B) of + + ... -> prev_node_c -> node_c -> ... + + And a corresponding related node_a, inserts the correct dtype + cast node after prev_node_c to cast into the dtype expected + by node_a, resulting in: + + dtype_cast + / + ... -> prev_node_c -> node_c -> ... + + For example, if node_c is an int8 op and node_a is an fp32 op, this function + will insert a dequant. + """ + dtype_cast_op = None + dtype_cast_mod_cls = None + dtype_cast_method = None + dtype_cast_method_dtype = None + dtype_cast_scale = None + dtype_cast_zero_point = None + node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type( + node_a, gm_a, logger_cls, node_type_to_io_type_map + ) + node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type( + node_c, gm_b, logger_cls, node_type_to_io_type_map + ) + + if ( + ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.INT8 + ) + or ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.FP16 + ) + or + # TODO(future PR): determine the actual dtype of node_c, + # the current code only works because dequantize works with + # multiple input dtypes. + ( + node_input_type_a == NodeInputOrOutputType.FP32 + and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8 + ) + ): + dtype_cast_op = torch.dequantize + elif ( + node_input_type_a == node_input_type_c + and node_input_type_a != NodeInputOrOutputType.UNKNOWN + ): + dtype_cast_mod_cls = torch.nn.Identity + elif ( + node_input_type_a == NodeInputOrOutputType.INT8 + and node_input_type_c == NodeInputOrOutputType.FP32 + ): + # int8 shadows fp32, the dtype cast needs to quantize to int8 + # with the right qparams. + node_a_input_qparams = get_node_input_qparams( + node_a, gm_a, node_type_to_io_type_map + ) + if node_a_input_qparams is not None: + dtype_cast_op = torch.quantize_per_tensor # type: ignore[assignment] + dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams + elif ( + node_input_type_a == NodeInputOrOutputType.FP16 + and node_input_type_c == NodeInputOrOutputType.FP32 + ): + dtype_cast_method = "to" + dtype_cast_method_dtype = torch.float16 + else: + raise AssertionError( + f"dtype cast from {node_input_type_c} {node_c.format_node()} to " + + f"{node_input_type_a} {node_a.format_node()} needs to be implemented" + ) + + if isinstance(prev_node_c, Node): + new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + if dtype_cast_op: + if dtype_cast_scale is not None and dtype_cast_zero_point is not None: + return _insert_quantize_per_tensor_node( + prev_node_c, + node_a, + gm_b, + graph_c, + dtype_cast_scale, + dtype_cast_zero_point, + new_dtype_cast_name, + ) + else: + return graph_c.create_node( + "call_function", + dtype_cast_op, + (prev_node_c,), + {}, + new_dtype_cast_name, + ) + elif dtype_cast_method: + return graph_c.create_node( + "call_method", + dtype_cast_method, + (prev_node_c, dtype_cast_method_dtype), + {}, + new_dtype_cast_name, + ) + else: + if not dtype_cast_mod_cls: + raise AssertionError("Expected dtype_cast_mod_cls to be not None") + dtype_cast_mod = dtype_cast_mod_cls() + setattr(gm_b, new_dtype_cast_name, dtype_cast_mod) + return graph_c.create_node( + "call_module", + new_dtype_cast_name, + (prev_node_c,), + {}, + new_dtype_cast_name, + ) + elif isinstance(prev_node_c, list): + results = [] + for prev_node_c_inner in prev_node_c: + new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + if dtype_cast_op: + # TODO(future PR): add handling for quantize_per_tensor + new_dtype_cast_node = graph_c.create_node( + "call_function", + dtype_cast_op, + (prev_node_c_inner,), + {}, + new_dtype_cast_name, + ) + results.append(new_dtype_cast_node) + else: + if not dtype_cast_mod_cls: + raise AssertionError("Expected dtype_cast_mod_cls to be not None") + dtype_cast_mod = dtype_cast_mod_cls() + setattr(gm_b, new_dtype_cast_name, dtype_cast_mod) + new_dtype_cast_node = graph_c.create_node( + "call_module", + new_dtype_cast_name, + (prev_node_c_inner,), + {}, + new_dtype_cast_name, + ) + results.append(new_dtype_cast_node) + return results + else: + raise AssertionError(f"type f{type(prev_node_c)} is not handled") + + +# TODO(future PR): look into using copy_node API instead +def _copy_node_from_a_to_c( + node_a: Node, + gm_a: GraphModule, + gm_b: GraphModule, + graph_c: Graph, +) -> Node: + """ + Simple copy of node_a to graph_c. + """ + if node_a.op == "get_attr": + node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")( + gm_b + ) + node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type] + if torch.is_tensor(node_a_obj): + node_a_obj = node_a_obj.detach() + setattr(gm_b, node_a_copy_name, node_a_obj) + node_a_copy = graph_c.create_node( + node_a.op, node_a_copy_name, (), {}, node_a_copy_name + ) + return node_a_copy + elif node_a.op == "call_method": + if node_a.target not in ("dequantize", "to"): + raise AssertionError(f"target {node_a.target} is not implemented") + if node_a.target == "dequantize": + arg_copy = _copy_node_from_a_to_c( + get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c + ) # type: ignore[arg-type] + node_a_copy_name = get_new_attr_name_with_prefix( + node_a.name + "_shadow_copy_" + )(gm_b) + node_a_copy = graph_c.create_node( + node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name + ) + return node_a_copy + else: # to + arg_copy = _copy_node_from_a_to_c( + get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c + ) # type: ignore[arg-type] + node_a_copy_name = get_new_attr_name_with_prefix( + node_a.name + "_shadow_copy_" + )(gm_b) + node_a_copy = graph_c.create_node( + node_a.op, + node_a.target, + (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)), + {}, + node_a_copy_name, + ) + return node_a_copy + + else: + raise AssertionError( + f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented" + ) + + +def _can_insert_copy_of_subgraph_a( + subgraph_a: NSSubgraph, + gm_a: GraphModule, + num_non_param_args_node_a: int, +) -> bool: + """ + This function returns `False` if the input subgraph cannot be copied by + `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means + that there is a corner case logic for which copy is not yet implemented. + """ + # populate the list of nodes we need to check + nodes = [] + cur_node = subgraph_a.end_node + while cur_node != subgraph_a.start_node: + nodes.append(cur_node) + cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment] + nodes.append(cur_node) + nodes.reverse() + + def _can_insert(node_a_arg, gm_a): + if isinstance(node_a_arg, Node): + arg_a = return_first_non_observer_node(node_a_arg, gm_a) + if arg_a.op == "call_method": + return arg_a.target in ("dequantize", "to") + elif arg_a.op == "get_attr": + return True + else: + return False + elif isinstance(node_a_arg, (list, tuple)): + for el in node_a_arg: + if not isinstance(el, Node): + return False + return True + + # For each node, check if we handle the copy behavior. This follows the + # logic in `_insert_copy_of_subgraph_a_after_input_node_c`. + for node_a in nodes: + local_num_non_param_args_node_a = ( + num_non_param_args_node_a if node_a is nodes[0] else 1 + ) + + norm_args_kwargs = node_a.normalized_arguments( + gm_a, normalize_to_only_use_kwargs=True + ) + if norm_args_kwargs is not None: + norm_args, norm_kwargs = norm_args_kwargs + else: + norm_args, norm_kwargs = node_a.args, node_a.kwargs + + cur_idx = 0 + + while cur_idx < len(norm_args): + if cur_idx == 0: + pass + elif cur_idx == 1 and local_num_non_param_args_node_a == 2: + pass + else: + if not _can_insert(norm_args[cur_idx], gm_a): + return False + cur_idx += 1 + + for kwarg_val in norm_kwargs.values(): + # stitch the inputs from base graph + if cur_idx == 0: + pass + elif cur_idx == 1 and local_num_non_param_args_node_a == 2: + pass + else: + if not _can_insert(kwarg_val, gm_a): + return False + cur_idx += 1 + + return True + + +def _insert_copy_of_subgraph_a_after_input_node_c( + input_node_c: Node | list[Node], + input_node_c_2: Node | list[Node] | None, + subgraph_a: NSSubgraph, + gm_a: GraphModule, + gm_b: GraphModule, + node_name_prefix: str, +) -> Node: + """ + TODO(before land): real docblock + """ + if not isinstance(input_node_c, (Node, list)): + raise AssertionError(f"Expected Node or list, got {type(input_node_c)}") + + # create a sequential list of the subgraphs' nodes from start to end, + # because we need to add the nodes to graph C in non-reverse order + nodes_of_a = [subgraph_a.end_node] + cur_node = subgraph_a.end_node + while cur_node != subgraph_a.start_node: + cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment] + nodes_of_a.insert(0, cur_node) + + # go through nodes of a in order, and insert them into the graph of c + # sequentially + cur_node_a = nodes_of_a[0] + cur_node_c = _insert_copy_of_node_a_after_input_node_c( + input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix + ) + for cur_idx_a in range(1, len(nodes_of_a)): + cur_node_a = nodes_of_a[cur_idx_a] + prev_node_c = cur_node_c # previous added node is the input to next node + cur_node_c = _insert_copy_of_node_a_after_input_node_c( + prev_node_c, + # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph + None, + cur_node_a, + gm_a, + gm_b, + node_name_prefix, + ) + # return the last inserted node + return cur_node_c + + +def _insert_copy_of_node_a_after_input_node_c( + input_node_c: Node | list[Node], + input_node_c_2: Node | list[Node] | None, + node_a: Node, + gm_a: GraphModule, + gm_b: GraphModule, + node_name_prefix: str, +) -> Node: + """ + Assume that node_a from graph_a has + args (input, (input2)?, arg1, ...), and + kwargs {kw0: kwarg0, ...} + + Note: input2 is optional. If it equals to None, we assume that the op + has a single non-param input. If it is specified, we assume that the op + has two non-param inputs. + + Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b, + and creates the corresponding nodes in graph_c. Note: observers are ignored, + so if an arg is an observer we navigate up until we find a non-observer parent. + + If node_a is a call_module, points the module pointed to by node_a to gm_b. + + Creates the copy of node_a in graph_c, with input as the first arg, + and all other args and kwargs pointing to the copies of the objects + in gm_b created above. + + An example in pictures: + + graph A: + ======== + + input -------------> node_a + / / / + (input_2)?----------/ / / + / / + weight -> weight_obs / + / + bias ---------------- + + graph C (derived from B): + ========================= + + input_node_c --> node_a_copy + / / / + (input_node_c_2)? / / + / / + weight_copy ----/ / + / + bias_copy ------/ + """ + if isinstance(input_node_c, Node): + graph_c = input_node_c.graph + else: + if not isinstance(input_node_c, list): + raise AssertionError(f"Expected list, got {type(input_node_c)}") + graph_c = input_node_c[0].graph + + norm_args_kwargs = node_a.normalized_arguments( + gm_a, normalize_to_only_use_kwargs=True + ) + if norm_args_kwargs is not None: + norm_args, norm_kwargs = norm_args_kwargs + else: + norm_args, norm_kwargs = node_a.args, node_a.kwargs + + new_args = [] + new_kwargs = {} + + def _copy_arg(arg): + # copy the other inputs from the other graph + if isinstance(arg, Node): + arg = return_first_non_observer_node(arg, gm_a) + arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c) + return arg + elif isinstance(arg, (int, float, torch.dtype)): + return arg + elif isinstance(kwarg_val, (list, tuple)): + for el in kwarg_val: + if isinstance(el, Node): + raise AssertionError( + "handling of Node inside list is not implemented" + ) + return arg + else: + raise AssertionError( + f"handling for kwarg of type {type(kwarg_val)} is not implemented" + ) + + cur_idx = 0 + + while cur_idx < len(norm_args): + if cur_idx == 0: + new_arg = input_node_c + elif cur_idx == 1 and input_node_c_2 is not None: + new_arg = input_node_c_2 + else: + new_arg = _copy_arg(norm_args[cur_idx]) + new_args.append(new_arg) + cur_idx += 1 + + for kwarg_name, kwarg_val in norm_kwargs.items(): + # stitch the inputs from base graph + if cur_idx == 0: + new_kwargs[kwarg_name] = input_node_c + elif cur_idx == 1 and input_node_c_2 is not None: + new_kwargs[kwarg_name] = input_node_c_2 + else: + new_kwargs[kwarg_name] = _copy_arg(kwarg_val) + cur_idx += 1 + + new_args = tuple(new_args) # type: ignore[assignment] + + node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + + if node_a.op == "call_module": + # if target is a module, we point to the module from gm_b + new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b) + # fetch the corresponding module from gm_a + if not isinstance(node_a.target, str): + raise AssertionError(f"Expected str, got {type(node_a.target)}") + mod_a = getattr_from_fqn(gm_a, node_a.target) + setattr(gm_b, new_mod_copy_name, mod_a) + node_a_shadows_c = graph_c.create_node( + node_a.op, + new_mod_copy_name, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, + ) + return node_a_shadows_c + else: + if node_a.op not in ("call_function", "call_method"): + raise AssertionError(f"Unexpected op: {node_a.op}") + node_a_shadows_c = graph_c.create_node( + node_a.op, + node_a.target, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, + ) + return node_a_shadows_c + + +def create_a_shadows_b( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], + logger_cls: Callable, + should_log_inputs: bool, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> GraphModule: + """ + Creates a new GraphModule consisting of the graph of C, with the meaningful + nodes of A shadowing the corresponding nodes of B. For example, + + Graph A: + a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2 + + Graph B: + b0 -> op0_int8 -> b1 -> op1_int8 -> b2 + + matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)} + + Graph C (A shadows B): + + / dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1 + / / + b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1 + + In a nutshell, this function does the following for each node pair: + * copies the necessary attributes and modules from gm_a to gm_b, + keeping names unique + * adds a dtype cast op (dequant, quant, etc) + * adds a copy of node_a in gm_b's graph + * adds loggers to the outputs of node_a and node_b + """ + + if node_type_to_io_type_map is None: + node_type_to_io_type_map = get_node_type_to_io_type_map() + + # graph_c is the graph created from copying the nodes of graph_b and inserting + # the shadows with the nodes copied from graph_a + graph_c = Graph() + env_c: dict[str, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env_c[node.name]) + + start_node_b_to_matched_subgraph_a_and_name = {} + end_node_b_to_matched_subgraph_a_and_name = {} + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = ( + subgraph_a, + match_name, + ref_node_type_a, + ref_node_type_b, + ) + end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = ( + subgraph_a, + match_name, + ref_node_type_a, + ref_node_type_b, + ) + + for node_b in gm_b.graph.nodes: + if node_b.op == "output": + graph_c.output(map_arg(node_b.args[0], load_arg)) + continue + + # calculate the flags to determine what to do with this node + node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name + node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name + + if node_b_is_start_node or node_b_is_end_node: + if node_b_is_start_node: + ( + subgraph_a, + ref_name, + ref_node_type_a, + ref_node_type_b, + ) = start_node_b_to_matched_subgraph_a_and_name[node_b] + else: + if not node_b_is_end_node: + raise AssertionError("Expected node_b_is_end_node to be not false") + ( + subgraph_a, + ref_name, + ref_node_type_a, + ref_node_type_b, + ) = end_node_b_to_matched_subgraph_a_and_name[node_b] + + all_op_types_support_shadowing = op_type_supports_shadowing( + subgraph_a.start_node + ) and op_type_supports_shadowing(node_b) + if not all_op_types_support_shadowing: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unsupported" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + # For both start_node and end_node verify that we know how to do + # the dtype cast. If we do not, skip. + ( + node_input_type_a, + node_output_type_a, + ) = get_node_first_input_and_output_type( + subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map + ) + ( + node_input_type_b, + node_output_type_b, + ) = get_node_first_input_and_output_type( + node_b, gm_b, logger_cls, node_type_to_io_type_map + ) + node_io_types_known_a_and_b = ( + node_input_type_a != NodeInputOrOutputType.UNKNOWN + and node_output_type_a != NodeInputOrOutputType.UNKNOWN + and node_input_type_b != NodeInputOrOutputType.UNKNOWN + and node_output_type_b != NodeInputOrOutputType.UNKNOWN + ) + if not node_io_types_known_a_and_b: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unknown dtype cast" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + # If we are shadowing from fp32 to int8, we need to insert + # quantize_per_tensor call with qparams from the previous node. + # Only do this if we are able to infer these qparams from the graph. + if ( + node_input_type_a == NodeInputOrOutputType.INT8 + and node_input_type_b == NodeInputOrOutputType.FP32 + ): + node_a_input_qparams = get_node_input_qparams( + subgraph_a.start_node, gm_a, node_type_to_io_type_map + ) + if not node_a_input_qparams: + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unknown input qparams" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + num_non_param_args_node_a = get_number_of_non_param_args( + subgraph_a.start_node, gm_a + ) + if not _can_insert_copy_of_subgraph_a( + subgraph_a, gm_a, num_non_param_args_node_a + ): + print( + f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}" + + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}" + + ", unhandled logic in subgraph copy" + ) + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + continue + + fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a) + fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b) # type: ignore[possibly-undefined] + + if node_b_is_start_node: + # if necessary, log the input of node_c + if should_log_inputs: + prev_node_b = get_normalized_nth_input(node_b, gm_b, 0) + if isinstance(prev_node_b, Node): + prev_node_c = env_c[prev_node_b.name] + env_c[prev_node_c.name] = _insert_logger_after_node( + prev_node_c, + gm_b, + logger_cls, + "_ns_logger_b_inp_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_b, + ) + elif isinstance(prev_node_b, list): + # first, save the prev_node instances, because they + # will be overwritten in the env after the first logger + # is added + prev_node_c_list = [env_c[arg.name] for arg in prev_node_b] + + for arg_idx, prev_node_c in enumerate(prev_node_c_list): + env_c[prev_node_c.name] = _insert_logger_after_node( + prev_node_c, + gm_b, + logger_cls, + "_ns_logger_b_inp_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=arg_idx, + index_of_arg=0, + fqn=fqn_base_b, + ) + else: + # logging of inputs which are not lists is not supported yet + raise AssertionError( + f"type {type(prev_node_b)} is not handled yet" + ) + # subgraph so far: + # + # (prev_node_c)+ -> (logger_c_input)? + + # Note: this if statement is always True, spelling it out to clarify code + # intent. + if node_b_is_start_node or node_b_is_end_node: + # ensure env_c is populated with base node + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + node_c = env_c[node_b.name] + + # after this point, + # + # node_a is the original node from graph_a, with parent module gm_a + # node_b is the original node from graph_b, with parent module gm_b + # node_c is the copy of node_b in graph_c + # + # subgraph so far: + # + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if node_b_is_start_node: + # cast dtype from the dtype of node_c's input to the dtype of + # node_a's input (dequant, etc) + # prev_node_c = node_c.args[0] + prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) # type: ignore[possibly-undefined] + if should_log_inputs: + # skip the input logger when inserting a dtype cast + if isinstance(prev_node_c, Node): + # pyrefly: ignore [unbound-name] + prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) + elif isinstance(prev_node_c, list): + prev_node_c = [ + get_normalized_nth_input(arg, gm_b, 0) + for arg in prev_node_c + ] + dtype_cast_node = _insert_dtype_cast_after_node( + subgraph_a.start_node, + # pyrefly: ignore [unbound-name] + node_c, + prev_node_c, + gm_a, + gm_b, + graph_c, + node_b.name + "_dtype_cast_", + logger_cls, + node_type_to_io_type_map, + ) + # note: not inserting to env_c because all nodes which use the dtype + # casts are copied from graph_a + # + # subgraph so far: + # + # (dtype_cast_node)+ + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + # if input logging is enabled, log the input to the subgraph + if should_log_inputs: + # TODO: explain this + ref_node_name = "" + if isinstance(dtype_cast_node, Node): + dtype_cast_node = _insert_logger_after_node( + dtype_cast_node, + gm_b, + logger_cls, + "_ns_logger_a_inp_", + ref_node_name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_a, + ) + input_logger: Node | list[Node] = dtype_cast_node + else: + if not isinstance(dtype_cast_node, list): + raise AssertionError( + f"Expected list, got {type(dtype_cast_node)}" + ) + new_loggers = [] + for dtype_cast_idx, dtype_cast_node_inner in enumerate( + dtype_cast_node + ): + dtype_cast_logger = _insert_logger_after_node( + dtype_cast_node_inner, + gm_b, + logger_cls, + "_ns_logger_a_inp_", + ref_node_name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_INPUT.value, + index_within_arg=dtype_cast_idx, + index_of_arg=0, + fqn=fqn_base_a, + ) + new_loggers.append(dtype_cast_logger) + dtype_cast_node = new_loggers + input_logger = dtype_cast_node + # subgraph so far: + # + # (dtype_cast_node)+ -> (logger_a_input)? + # / + # prev_node_c -> (logger_c_input)? -> node_start_c + + # hook up the new mod_a copy to be in the graph, receiving the + # same inputs as mod_b does, with dtype cast to match a + # Some ops, such as LSTMs, have two non-param inputs. If we have + # such an op, pass the second param as well. Note: dtype casting + # for the second param is not implemented yet, it can be added + # later if there is a use case. + node_c_second_non_param_arg = None + num_non_param_args_node_a = get_number_of_non_param_args( + subgraph_a.start_node, gm_a + ) + if num_non_param_args_node_a == 2: + # node_c_second_non_param_arg = node_c.args[1] + node_c_second_non_param_arg = get_normalized_nth_input( + # pyrefly: ignore [unbound-name] + node_c, + gm_b, + 1, + ) + node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c( + dtype_cast_node, + node_c_second_non_param_arg, + subgraph_a, + gm_a, + gm_b, + # pyrefly: ignore [unbound-name] + node_c.name + "_shadow_copy_", + ) + env_c[node_a_shadows_c.name] = node_a_shadows_c + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown) + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if should_log_inputs: + # When we created the input logger, we left the ref_node_name + # as an empty string, because the subgraph copy did not exist + # yet. Now that the subgraph copy exists, we modify this name + # to its true value. + # Note: the alternative to this is to create the input logger + # after creating the subgraph, which is slightly more + # complicated. This is the lesser of two evils. + # input_logger = env_c[dtype_cast_node.name] + # Find the first node in the subgraph + cur_node = node_a_shadows_c + while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined] + cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment] + # pyrefly: ignore [unbound-name] + if isinstance(input_logger, Node): + # pyrefly: ignore [unbound-name] + input_logger_mod = getattr(gm_b, input_logger.name) + input_logger_mod.ref_node_name = cur_node.name + else: + # pyrefly: ignore [unbound-name] + if not isinstance(input_logger, list): + raise AssertionError( + # pyrefly: ignore [unbound-name] + f"Expected list, got {type(input_logger)}" + ) + # pyrefly: ignore [unbound-name] + for input_logger_inner in input_logger: + input_logger_mod = getattr(gm_b, input_logger_inner.name) + input_logger_mod.ref_node_name = cur_node.name + + # hook up a logger to the mod_a copy + env_c[node_a_shadows_c.name] = _insert_logger_after_node( + env_c[node_a_shadows_c.name], + gm_b, + logger_cls, + "_ns_logger_a_", + node_a_shadows_c.name, + name_a, + ref_name, + ref_node_type_a, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_a, + ) + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a + # / + # (prev_node_c)+ -> (logger_c_input)? -> node_start_c + + if node_b_is_end_node: + # hook up a logger to the mod_b copy + env_c[node_b.name] = _insert_logger_after_node( + env_c[node_b.name], + gm_b, + logger_cls, + "_ns_logger_b_", + node_b.name, + name_b, + ref_name, + ref_node_type_b, + NSSingleResultValuesType.NODE_OUTPUT.value, + index_within_arg=0, + index_of_arg=0, + fqn=fqn_base_b, + ) + # subgraph so far: + # + # dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a + # / + # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c + # + # Note: node_start_c may be the same node as node_end_c, or they + # may have nodes in between. + + else: + env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) + + gm_c = GraphModule(gm_b, graph_c) + return gm_c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..275291789f1c5461af366038d7702801bf5fc303 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py @@ -0,0 +1,763 @@ +import operator +from typing import TYPE_CHECKING + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend +import torch.ao.quantization.quantization_mappings as quantization_mappings +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.backend_config import get_native_backend_config + +from .ns_types import NSNodeTargetType + + +if TYPE_CHECKING: + from collections.abc import Callable + + +toq = torch.ops.quantized + + +def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]: + # note: this set is modified below by items from backend_config + sets_of_related_ops: list[set[NSNodeTargetType]] = [ + # conv modules + { + nn.Conv1d, + }, + { + nn.Conv2d, + }, + { + nn.Conv3d, + }, + # conv functionals + { + F.conv1d, + }, + { + F.conv2d, + }, + { + F.conv3d, + }, + # linear modules + { + nn.Linear, + }, + # linear functionals + { + F.linear, + }, + # average pool + { + nn.AvgPool1d, + torch.avg_pool1d, + }, + { + nn.AvgPool2d, + torch._C._nn.avg_pool2d, + }, + { + nn.AvgPool3d, + torch._C._nn.avg_pool3d, + }, + # adaptive average pool + { + nn.AdaptiveAvgPool1d, + F.adaptive_avg_pool1d, + }, + { + nn.AdaptiveAvgPool2d, + F.adaptive_avg_pool2d, + }, + { + nn.AdaptiveAvgPool3d, + F.adaptive_avg_pool3d, + }, + # LSTM + { + nn.LSTM, + }, + # add + { + torch.add, + operator.add, # x + y + }, + # cat + { + torch.cat, + }, + # mul + { + torch.mul, + operator.mul, + }, + # relu + { + F.relu, + nn.ReLU, + "relu", + "relu_", + torch.relu, + }, + # maxpool + { + nn.MaxPool1d, + F.max_pool1d, + }, + { + nn.MaxPool2d, + F.max_pool2d, + }, + { + nn.MaxPool3d, + F.max_pool3d, + }, + # sigmoid + { + torch.sigmoid, + "sigmoid", + "sigmoid_", + nn.Sigmoid, + F.sigmoid, + }, + # BatchNorm + { + nn.BatchNorm2d, + }, + { + nn.BatchNorm3d, + }, + # ConvTranspose + { + nn.ConvTranspose1d, + }, + { + nn.ConvTranspose2d, + }, + { + nn.ConvTranspose3d, + }, + # functional transposed conv + { + F.conv_transpose1d, + }, + { + F.conv_transpose2d, + }, + { + F.conv_transpose3d, + }, + # ELU + { + nn.ELU, + }, + # Embedding + { + nn.Embedding, + }, + # EmbeddingBag + { + nn.EmbeddingBag, + }, + # GroupNorm + { + nn.GroupNorm, + }, + # Hardswish + { + nn.Hardswish, + }, + # InstanceNorm + { + nn.InstanceNorm1d, + }, + { + nn.InstanceNorm2d, + }, + { + nn.InstanceNorm3d, + }, + # LayerNorm + { + nn.LayerNorm, + }, + # LeakyReLU + { + nn.LeakyReLU, + }, + # ReLU6 + { + nn.ReLU6, + F.relu6, + }, + # F.elu + { + F.elu, + }, + # F.hardswish + { + F.hardswish, + }, + # F.group_norm + { + F.group_norm, + }, + # F.instance_norm + { + F.instance_norm, + }, + # F.layer_norm + { + F.layer_norm, + }, + # F.leaky_relu + { + F.leaky_relu, + }, + # F.silu + { + nn.SiLU, + F.silu, + }, + # F.mish + { + nn.Mish, + F.mish, + }, + # F.tanh + { + nn.Tanh, + F.tanh, + torch.tanh, + "tanh_", + "tanh", + }, + # F.hardsigmoid + { + "hardsigmoid_", + "hardsigmoid", + F.hardsigmoid, + nn.Hardsigmoid, + }, + # F.hardtanh + { + nn.Hardtanh, + F.hardtanh, + F.hardtanh_, + }, + # floordiv + { + operator.floordiv, + }, + # unsqueeze + { + torch.unsqueeze, + }, + # stack + { + torch.stack, + }, + # squeeze + { + torch.squeeze, + }, + # sort + { + torch.sort, + }, + # repeat_interleave + { + torch.repeat_interleave, + }, + # min + { + torch.min, + }, + # mean + { + torch.mean, + }, + # max + { + torch.max, + }, + # transpose + { + torch.transpose, + }, + # flatten + { + torch.flatten, + }, + # clamp + { + torch.clamp, + }, + # chunk + { + torch.chunk, + }, + # interpolate + { + torch.nn.functional.interpolate, + }, + # dropout + { + nn.Dropout, + }, + # F.dropout + { + F.dropout, + }, + # matmul + { + torch.matmul, + }, + # Softmax + { + nn.Softmax, + }, + # PReLU + { + nn.PReLU, + nnq.PReLU, + }, + # F.prelu + { + F.prelu, + toq.prelu, + }, + # pixel shuffle + { + nn.PixelShuffle, + }, + { + F.pixel_shuffle, + }, + # pixel unshuffle + { + nn.PixelUnshuffle, + }, + { + F.pixel_unshuffle, + }, + # narrow + { + torch.narrow, + }, + ] + + # for each floating point op, add versions of the op added by + # backend_config + backend_config = get_native_backend_config() + + new_connections: list[tuple[Callable, Callable]] = [ + # technical debt edge case + (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear), + ] + + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + # pattern format: (c, (b, a)) + first_element = pattern + # look from the end, because pattern is in reverse order + while isinstance(first_element, (list, tuple)): + first_element = first_element[-1] + + if config.fused_module is not None: + # case 1: pattern fuses a pattern of ops into an op + # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d + new_connections.append((first_element, config.fused_module)) + + if config.qat_module is not None: + # case 2: pattern swaps a module into a QAT module + # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d + new_connections.append((first_element, config.qat_module)) + + if config.reference_quantized_module is not None: + # case 3: reference version of floating point module, such as + # nn.Conv2d and nnqr.Conv2d + new_connections.append((first_element, config.reference_quantized_module)) + + # + # Add reference module swaps from default lowering path + # + + for source_to_target in ( + _lower_to_native_backend.STATIC_LOWER_MODULE_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP, + _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP, + _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP, + ): + for source, target in source_to_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target)) + + for source_to_double_target in ( + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP, + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP, + ): + for source, (target1, target2) in source_to_double_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target1)) + new_connections.append((source, target2)) + + # + # Add function swaps from default lowering path + # + + for source, ( # type:ignore[assignment] + target1, + target2, + ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): + new_connections.append((source, target1)) + # pyrefly: ignore [bad-argument-type] + new_connections.append((source, target2)) + + for source_to_target in ( + _lower_to_native_backend.QBIN_OP_MAPPING, + _lower_to_native_backend.QBIN_RELU_OP_MAPPING, + quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + ): + for source, target in source_to_target.items(): # type:ignore[assignment] + # pyrefly: ignore [bad-argument-type] + new_connections.append((source, target)) + + # + # Add other swaps, ideally in the future this could be removed + # after the lowering code stops using these. + # + for source_to_target in ( + quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + ): + for source, target in source_to_target.items(): # type:ignore[assignment] + new_connections.append((source, target)) + + # add the new connections from backend_config + for item1, item2 in new_connections: + for set_of_related_ops in sets_of_related_ops: + if item1 in set_of_related_ops or item2 in set_of_related_ops: + set_of_related_ops.add(item1) + set_of_related_ops.add(item2) + break + + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] = {} + + for counter, set_of_related_ops in enumerate(sets_of_related_ops): + base_name = str(counter) + base_name_to_sets_of_related_ops[base_name] = set_of_related_ops + + return base_name_to_sets_of_related_ops + + +def get_base_name_for_op( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + op: NSNodeTargetType, +) -> str | None: + for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items(): + if op in set_of_related_ops: + return base_name + return None + + +def add_op_to_sets_of_related_ops( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], + op: NSNodeTargetType, + related_op: NSNodeTargetType | None, +) -> None: + if related_op is not None: + for set_of_related_ops in base_name_to_sets_of_related_ops.values(): + if related_op in set_of_related_ops: + set_of_related_ops.add(op) + return + # if we got here, related_op was not found + raise AssertionError(f"{related_op} was not found") + else: + counter = 0 + while str(counter) in base_name_to_sets_of_related_ops: + counter += 1 + base_name_to_sets_of_related_ops[str(counter)] = {op} + + +# TODO(future PR): clean this up +def get_node_type_to_io_type_map() -> dict[str, set[NSNodeTargetType]]: + FUNS_IO_TYPE_FP32: set[NSNodeTargetType] = { + F.linear, + F.conv1d, + F.conv2d, + F.conv3d, + torch.cat, + F.elu, + F.hardswish, + F.instance_norm, + F.layer_norm, + F.leaky_relu, + F.dropout, + F.silu, + F.mish, + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.sum, + F.prelu, + } + + FUNS_IO_TYPE_FP16: set[NSNodeTargetType] = set() + + FUNS_IO_TYPE_INT8: set[NSNodeTargetType] = { + toq.linear, + toq.linear_relu, + toq.conv1d, + toq.conv1d_relu, + toq.conv2d, + toq.conv2d_relu, + toq.conv3d, + toq.conv3d_relu, + toq.cat, + toq.elu, + toq.hardswish, + toq.instance_norm, + toq.layer_norm, + toq.leaky_relu, + toq.dropout, + toq.prelu, + # TODO(future PR): implement shadowing for binary ops and + # uncomment below + # toq.add, + # toq.mul, + } + + FUNS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + F.relu, + F.tanh, + torch.tanh, + F.sigmoid, + torch.sigmoid, + F.hardsigmoid, + operator.floordiv, + torch.adaptive_avg_pool1d, + F.adaptive_avg_pool2d, + F.adaptive_avg_pool3d, + F.dropout, + F.hardtanh, + F.hardtanh_, + F.interpolate, + F.max_pool1d, + F.max_pool2d, + F.max_pool3d, + F.relu6, + F.pixel_shuffle, + F.pixel_unshuffle, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.cat, + torch.chunk, + torch.clamp, + torch.flatten, + torch.transpose, + torch.max, + torch.mean, + torch.min, + torch.narrow, + torch.repeat_interleave, + torch.sort, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.add, + } + + MODS_IO_TYPE_FP32: set[NSNodeTargetType] = { + nn.Linear, + nnqat.Linear, + nnqatd.Linear, + nnqd.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + nnqat.Embedding, + nnqat.EmbeddingBag, + nn.LSTM, + # note: nnqd.Linear is an instance of nnq.Linear, so this + # check has to happen before the int8 module check + nnqd.LSTM, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.Dropout, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.ELU, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + nn.Hardswish, + nn.LeakyReLU, + nn.ReLU6, + nn.SiLU, + nn.Mish, + nn.Softmax, + nn.PReLU, + nni.BNReLU2d, + nni.BNReLU3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.LinearReLU, + nni.LinearBn1d, + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nniqat.ConvBn1d, + nniqat.ConvBn2d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU1d, + nniqat.ConvBnReLU2d, + nniqat.ConvBnReLU3d, + nniqat.ConvReLU1d, + nniqat.ConvReLU2d, + nniqat.ConvReLU3d, + nniqat.LinearReLU, + nniqat.LinearBn1d, + nniqd.LinearReLU, + nni.LinearLeakyReLU, + nni.LinearTanh, + nni.ConvAdd2d, + nni.ConvAddReLU2d, + } + + MODS_IO_TYPE_INT8: set[NSNodeTargetType] = { + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.Dropout, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.Embedding, + nnq.EmbeddingBag, + nnq.Dropout, + nnq.Softmax, + nnq.PReLU, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + nniq.LinearLeakyReLU, + nniq.LinearTanh, + nniq.ConvAdd2d, + nniq.ConvAddReLU2d, + } + + MODS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + nn.ReLU, + nn.Tanh, + nn.Sigmoid, + nn.Hardsigmoid, + nn.AdaptiveAvgPool1d, + nn.AdaptiveAvgPool2d, + nn.AdaptiveAvgPool3d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.AvgPool3d, + nn.Dropout, + nn.Hardtanh, + nn.Identity, + nn.MaxPool1d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.PixelShuffle, + nn.PixelUnshuffle, + nn.ReLU6, + } + + METHS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = { + "sigmoid_", + "sigmoid", + "tanh_", + "tanh", + "hardsigmoid_", + "hardsigmoid", + "relu_", + "relu", + } + + return { + "funs_io_type_fp32": FUNS_IO_TYPE_FP32, + "funs_io_type_fp16": FUNS_IO_TYPE_FP16, + "funs_io_type_int8": FUNS_IO_TYPE_INT8, + "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8, + "mods_io_type_fp32": MODS_IO_TYPE_FP32, + "mods_io_type_int8": MODS_IO_TYPE_INT8, + "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8, + "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8, + } + + +def get_unmatchable_types_map() -> dict[str, set[NSNodeTargetType]]: + FUNS_UNMATCHABLE: set[NSNodeTargetType] = { + torch.quantize_per_tensor, + operator.getitem, + } + + MODS_UNMATCHABLE: set[NSNodeTargetType] = { + nn.Identity, + } + + METHS_UNMATCHABLE: set[NSNodeTargetType] = { + "to", + "dequantize", + "reshape", + "view", + "unsqueeze_", + "unsqueeze", + "transpose", + "squeeze_", + "squeeze", + "size", + "shape", + "resize_", + "repeat_interleave", + "repeat", + "permute", + "numel", + "mean", + "detach_", + "detach", + "contiguous", + "clamp", + "chunk", + } + + return { + "funs_unmatchable": FUNS_UNMATCHABLE, + "mods_unmatchable": MODS_UNMATCHABLE, + "meths_unmatchable": METHS_UNMATCHABLE, + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95d467d9337ea24d676d282740df042d5bdd16f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py @@ -0,0 +1,1416 @@ +# mypy: allow-untyped-defs +import collections +import copy +import operator +from collections.abc import Callable +from typing import Any + +import torch +import torch.fx +from torch.ao.ns.fx.graph_passes import _maybe_get_fqn +from torch.ao.ns.fx.ns_types import NSResultsType, NSSingleResultValuesType +from torch.ao.ns.fx.utils import ( # TODO(future PR): make this work correctly for methods + get_normalized_nth_input, + get_target_type_str, +) +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.fx.match_utils import _MatchResult +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import Graph, GraphModule, Node +from torch.utils._pytree import tree_map + + +SHADOW_NODE_NAME_PREFIX = "shadow" +SHADOW_WRAPPER_NODE_NAME_PREFIX = "shadow_wrapper" + +# TODO(future PR): reuse existing mapping instead of creating a new one +BINARY_FUNCTIONS = { + torch.add, + torch.Tensor.add, + operator.add, + torch.mul, + torch.Tensor.mul, + operator.mul, +} + + +def _get_attr_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + + +def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + + +class OutputProp: + """ + Output propagation (modeled from shape propagation). + + Given a GraphModule and an example input, saves the output flowing + through each node on `node.traced_result`. + + Code based on the example from + https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern + """ + + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env: dict[str, Node] = {} + + def load_arg(a): + return torch.fx.graph.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target: str): + target_atoms = target.split(".") + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == "placeholder": + result = next(args_iter) + elif node.op == "get_attr": + result = fetch_attr(node.target) + elif node.op == "call_function": + result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == "call_method": + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == "call_module": + result = self.modules[node.target]( + *load_arg(node.args), **load_arg(node.kwargs) + ) + + if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + node.traced_result = result + + # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore [unbound-name] + env[node.name] = result + + return None + + +def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Node]]: + # the original matches variable is unique by node, make it unique by subgraph + # instead + seen_nodes = set() + subgraphs_dedup = {} + + # Dict items are not reversible until Python 3.8, so we hack it + # to be compatible with previous Python versions + # TODO(future PR): try reversed(list(matches.items())) + matches_items_reversed: list[tuple[str, _MatchResult]] = list( + reversed(matches.items()) + ) + + # Note: the order is important. `matches` currently provides the matches + # in reverse order. We would like to process the matches in non-reverse + # order, so that we can create an intuitive naming scheme, such as + # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)` + for name, cur_match in matches_items_reversed: # type: ignore[call-overload] + was_seen = False + for node_or_tuple in cur_match[1]: + # Cur_match[1] has an unusual type. It says that it's a `List[Node]`, + # but it is really not. Furthermore, the contents of this field + # can change from match results of multiple nodes of the same pattern + # + # For example, for conv -> bn -> relu, we see + # match_results = { + # 'conv': (relu, [(bn, conv), relu], ...), + # 'bn': (relu, [(bn, conv), relu], ...), + # 'relu': (relu, [(bn, conv), relu], ...), + # } + # + # Ideally we should clean up the `find_matches` function to make + # this more intuitive. For the purposes of this prototype, we hack + # around it. + + if isinstance(node_or_tuple, Node): + if node_or_tuple in seen_nodes: + was_seen = True + seen_nodes.add(node_or_tuple) + + else: + if not isinstance(node_or_tuple, tuple): + raise AssertionError(f"Expected tuple, got {type(node_or_tuple)}") + for node in node_or_tuple: + if not isinstance(node, Node): + raise AssertionError(f"Expected Node, got {type(node)}") + if node in seen_nodes: + was_seen = True + seen_nodes.add(node) + + if was_seen: + continue + + # Start with the unusual type, convert it to [op_0, ..., op_n] + list_of_nodes = [] + + if len(cur_match[1]) == 1: + list_of_nodes = cur_match[1] + else: + if len(cur_match[1]) != 2: + raise ValueError( + f"Expected cur_match[1] to have length 2, got {len(cur_match[1])}" + ) + # either (a, b), or ((a, b), c) or (c, (a, b)) + # cannot make any assumptions on order, not clear what the + # _find_matches function is doing to populate this + # TODO(future PR): make this code less confusing, see discussion + # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836 + + def _order_nodes(node_a, node_b, node_c) -> list[Node]: + nodes = [node_a, node_b, node_c] + first_node = None + mid_node = None + last_node = None + for n in nodes: + prev_n = n.args[0] + next_n = next(iter(n.users)) + if prev_n not in nodes: + first_node = n + elif next_n not in nodes: + last_node = n + else: + mid_node = n + if first_node is None or mid_node is None or last_node is None: + raise AssertionError("Expected all nodes to be non-None") + if mid_node.args[0] is not first_node: + raise AssertionError("Expected mid_node.args[0] to be first_node") + if last_node.args[0] is not mid_node: + raise AssertionError("Expected last_node.args[0] to be mid_node") + return [last_node, mid_node, first_node] + + if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node): + # (a, b) + list_of_nodes = cur_match[1] + elif isinstance(cur_match[1][0], tuple): + # ((a, b), c) + node_a, node_b = cur_match[1][0] + node_c = cur_match[1][1] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + elif isinstance(cur_match[1][1], tuple): + # (a, (b, c)) + node_a, node_b = cur_match[1][1] + node_c = cur_match[1][0] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + + # [node_n, ..., node_0], note that the order is reversed + # to make it chronological for simple subgraphs + list_of_nodes.reverse() + subgraphs_dedup[name] = list_of_nodes + + return subgraphs_dedup + + +def _get_logger_for_subgraph( + model: GraphModule, + first_node: Node, + last_node: Node, + subgraph_idx: int, + subgraph_candidate_idx: int, + qconfig_str: str, + logger_cls: Callable, + fqn: str | None, +) -> torch.nn.Module: + """ + Given a model and a linear subgraph starting from `first_node` and + ending with `last_node`, creates a logger for the end of this + subgraph. + """ + if fqn is None: + fqn = "" + logger_mod_orig = logger_cls( + first_node.name, # ref_node_name + last_node.name, # prev_node_name + f"subgraph_{subgraph_idx}_{subgraph_candidate_idx}", # model_name + "model", # ref_name + get_target_type_str(last_node, model), # prev_node_target_type + get_target_type_str(first_node, model), # ref_node_target_type + NSSingleResultValuesType.NODE_OUTPUT.value, # results_type + 0, # index_within_arg + 0, # index_of_arg + fqn, # fqn + qconfig_str, + ) + # Usually we expect the user to add loggers, then calibrate, then convert, + # and then populate loggers. This is why the loggers start disabled. + # TODO(future PR): reconsider the design to make this more intuitive. + logger_mod_orig.enabled = False + return logger_mod_orig + + +def create_submodule_from_subgraph( + model: torch.nn.Module, + first_node: Node, + last_node: Node, +) -> GraphModule: + """ + Input: a model, and a linear subgraph within the model from first_node to + last_node. + + Output: a new submodule containing a copy of the subgraph, with the inputs + to the first node becoming the inputs to the submodule, and all other + nodes in the subgraph being copied. + + Example inputs: + + `model`: a module with graph + + x0 -> op1 -> x1 -> op2 -> x2 + | + arg1 + + `first_node`: op1 + `last_node`: op2 + + Example output: a new module with graph + + input1 -> op1_copy -> x1 -> op2_copy -> output1 + | + arg1 + """ + + # + # create a blank GraphModule with an empty graph + # + + class M(torch.nn.Module): + def forward(self, x): + pass + + m = M() + gm = torch.fx.symbolic_trace(m) + g = gm.graph + for node in reversed(gm.graph.nodes): + g.erase_node(node) + + # + # modify the graph to have a copy of our subgraph + # + + cur_node_orig = first_node + + cur_name_idx = 0 + + iteration_limit = 100 + cur_iteration = 0 + + while True: + if cur_node_orig is first_node: + # we are at the first node, we need to set up graph inputs + # TODO(future): some graphs could have placeholders which are unrelated + # to the first node, need to handle this + cur_args_copy = [] + cur_kwargs_copy = {} + seen_names: set[str] = set() + old_name_to_new_node: dict[str, Node] = {} + + def _add_placeholder( + g: Graph, node: Node, seen_names, old_name_to_new_node + ): + # note: for graphs starting with patterns such as `y = x + x`, we + # need to ensure we do not add multiple placeholders with the + # same name + counter = 0 + while node.name + "_" + str(counter) in seen_names: + counter += 1 + cur_name = node.name + "_" + str(counter) + seen_names.add(cur_name) + placeholder = g.placeholder(cur_name) + old_name_to_new_node[node.name] = placeholder + return placeholder + + for arg in cur_node_orig.args: + if isinstance(arg, Node): + p = _add_placeholder(g, arg, seen_names, old_name_to_new_node) + cur_args_copy.append(p) + elif isinstance(arg, (list, tuple)): + new_arg = [] + for inner_arg in arg: + if isinstance(inner_arg, Node): + new_arg.append( + _add_placeholder( + g, inner_arg, seen_names, old_name_to_new_node + ) + ) + else: + new_arg.append(inner_arg) + cur_args_copy.append(new_arg) + else: + cur_args_copy.append(arg) + + # TODO(future PR): handle non-normalized kwargs + for kwarg_name, kwarg in cur_node_orig.kwargs.items(): + if isinstance(kwarg, Node): + cur_kwargs_copy[kwarg_name] = _add_placeholder( + g, kwarg, seen_names, old_name_to_new_node + ) + elif isinstance(kwarg, (list, tuple)): + new_kwarg = [] + for inner_kwarg in kwarg: + p = _add_placeholder( + g, + inner_kwarg, # type: ignore[arg-type] + seen_names, + old_name_to_new_node, + ) + new_kwarg.append(p) + cur_kwargs_copy[kwarg_name] = new_kwarg + else: + cur_kwargs_copy[kwarg_name] = kwarg + + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + else: + # we are not at first node, first arg is from the previous node, + # and all other args are copied + + # the current implementation is simplistic and cannot handle + # ops with two or more arguments which need to be passed from + # the previous op, so we assert them out + if cur_node_orig.target in BINARY_FUNCTIONS: + raise AssertionError( + f"Unexpected binary function target: {cur_node_orig.target}" + ) + + # at this point in the code, cur_node_copy is pointing to the copy + # of the previous node + # TODO(future PR): this is not handling complicated graphs correctly, need to + # look at actual relationships instead of assuming sequential graph + # TODO(future PR): this is ignoring kwargs, will need to support kwargs + # for any fusion pattern which has them for a node that is not the + # first node. + cur_args_copy = [cur_node_copy] # type: ignore[has-type, possibly-undefined] # noqa: F821 + + if len(cur_node_orig.args) > 1: + for arg in cur_node_orig.args[1:]: + if isinstance(arg, torch.nn.Parameter): + new_arg = arg.detach().clone() # type: ignore[assignment] + mod_name = f"mod_{cur_name_idx}" + cur_name_idx += 1 + setattr(gm, mod_name, new_arg) + new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator] + # pyrefly: ignore [missing-attribute] + cur_args_copy.append(new_arg_placeholder) + elif isinstance(arg, (float, int, torch.dtype)): + # pyrefly: ignore [missing-attribute] + cur_args_copy.append(arg) + else: + raise AssertionError(f"arg of type {type(arg)} not handled yet") + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + + # copy the node + if cur_node_orig.op == "call_module": + orig_mod = getattr_from_fqn(model, cur_node_orig.target) # type: ignore[arg-type] + orig_mod_copy = copy.deepcopy(orig_mod) + mod_name = f"mod_{cur_name_idx}" + setattr(gm, mod_name, orig_mod_copy) + cur_name_idx += 1 + cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined,arg-type] + + elif cur_node_orig.op == "call_function": + cur_node_copy = g.call_function( + cur_node_orig.target, # type: ignore[arg-type] + cur_args_copy, # type: ignore[arg-type] + cur_kwargs_copy, # type: ignore[possibly-undefined] + ) + + elif cur_node_orig.op == "call_method": + cur_node_copy = g.call_method( + cur_node_orig.target, # type: ignore[arg-type] + cur_args_copy, # type: ignore[arg-type] + cur_kwargs_copy, # type: ignore[possibly-undefined] + ) + + else: + raise AssertionError(f"{cur_node_orig.op} not supported yet") + + if cur_node_orig is last_node: + break + + # go to next node + if len(cur_node_orig.users.keys()) != 1: + raise AssertionError( + f"{cur_node_orig} has more than 1 users, not supported yet" + ) + cur_node_orig = next(iter(cur_node_orig.users.keys())) + cur_iteration += 1 + if cur_iteration > iteration_limit: + raise AssertionError("iteration limit exceeded") + + # set up outputs + g.output(cur_node_copy) + + gm.recompile() + return gm + + +def create_one_transformed_and_logged_copy_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + subgraph_candidate_idx: int, + first_node: Node, + last_node: Node, + fqn: str | None, + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]], + example_inputs: Any, + last_added_shadow_node_list: list[Node | None], + custom_prepare_fn: Callable | None = None, + custom_prepare_kwargs: dict[str, Any] | None = None, +) -> None: + """ + Given a subgraph in `mt` and a subgraph candidate idx, inserts the + subgraph candidate copy and instruments it with loggers. + + If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just + add a logger to the end. + + If subgraph_candidate_idx is not 0, we create a copy of the subgraph and + prepare it with `prepare_fx`. + """ + + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger + + if subgraph_candidate_idx == 0: + # idx = 0 is the floating point (original) version of the subgraph + # We keep the subgraph as is, and add a logger at the end + + qconfig_str = "" + logger_mod_orig = _get_logger_for_subgraph( + mt, + first_node, + last_node, + subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputLogger, + fqn, + ) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + if hasattr(mt, attr_name): + raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}") + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(last_node): + new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={}) + last_added_shadow_node_list[0] = new_node + + else: + # idx > 0 means we have a candidate qconfig to try, so we need + # to make a copy of the subgraph, feed it with the right inputs, + # and add a logger at the end + + # get the qconfig + # subtract one because the first candidate is the floating point + # version of the subgraph + node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + + # if no quantization is requested, skip + # TODO(future PR): deduplicate equivalent qconfigs that come from + # different qconfig mapping objects + if qconfig is None: + return + + qconfig_mapping = QConfigMapping().set_global(qconfig) + + # create a copy of the submodule, wrapped in a separate module + orig_mod_copy_wrapped = create_submodule_from_subgraph( + mt, first_node, last_node + ) + + # add a call to prepare_fx on the wrapper module + if custom_prepare_fn is None: + orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( + orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs + ) + else: + if custom_prepare_kwargs is None: + custom_prepare_kwargs = {} + for kwarg_name in [ + "example_inputs", + "prepare_custom_config", + "qconfig_mapping", + ]: + if kwarg_name in custom_prepare_kwargs: + raise AssertionError( + f"cannot specify {kwarg_name} in custom_prepare_kwargs" + ) + prepare_kwargs: dict[str, Any] = { + "example_inputs": example_inputs, + "qconfig_mapping": qconfig_mapping, + } + prepare_kwargs.update(custom_prepare_kwargs) + orig_mod_copy_wrapped = custom_prepare_fn( + orig_mod_copy_wrapped, **prepare_kwargs + ) + + # attach the wrapper to the model + attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx) + if hasattr(mt, attr_name): + raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}") + setattr(mt, attr_name, orig_mod_copy_wrapped) + + # add a call to the wrapper module from the parent graph + insert_after_node = last_added_shadow_node_list[0] + with mt.graph.inserting_after(insert_after_node): + # TODO(future PR): handle fusion patterns where non-first nodes + # need inputs + + # pass in all node args and kwargs + + new_args = [] + for arg in first_node.args: + if isinstance(arg, Node): + new_args.append(arg) + elif ( + isinstance(arg, (list, tuple)) + and len(arg) + and isinstance(arg[0], Node) + ): + new_args.extend( + inner_arg for inner_arg in arg if isinstance(inner_arg, Node) + ) + + new_kwargs = {} + for name, old_kwarg in first_node.kwargs.items(): + if isinstance(old_kwarg, Node): + new_kwargs[name] = old_kwarg + elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg): + # TODO(future PR): clarify why we are adding kwargs to args + new_args.extend(old_kwarg) # type: ignore[arg-type] + + new_args = tuple(new_args) # type: ignore[assignment] + + new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs) # type: ignore[arg-type] + + # add a logger to parent graph to observe the shadow wrapper + logger_mod_orig = _get_logger_for_subgraph( + mt, + first_node, + last_node, + subgraph_idx, + subgraph_candidate_idx, + str(qconfig), + OutputComparisonLogger, + fqn, + ) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + if hasattr(mt, attr_name): + raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}") + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(new_node): + logger = mt.graph.call_module( + attr_name, args=(new_node, last_node), kwargs={} + ) + last_added_shadow_node_list[0] = logger + + mt.recompile() + + +def create_n_transformed_and_logged_copies_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + match_name: str, + nodes_in_this_subgraph: list[Any], + qconfig_mappings: list[QConfigMapping], + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]], + custom_prepare_fn: Callable | None = None, + custom_prepare_kwargs: dict[str, Any] | None = None, +) -> None: + """ + Given a model `mt` and a subgraph_idx, creates the needed copies + of the subgraph for all qconfigs, and instruments them with loggers. + """ + # for now, assume that + # 1. the first node has one input + # 2. the last node has one output + + # for now, ignore all subgraphs that contain non-nodes (tuples, etc) + # TODO(future PR): implement this + if any(not isinstance(node, Node) for node in nodes_in_this_subgraph): + return + + first_node = nodes_in_this_subgraph[0] + last_node = nodes_in_this_subgraph[-1] + # We used output propagation to populate example values on each + # node. Use the example values from the previous node as the input + # to the current node. + prev_node = get_normalized_nth_input(first_node, mt, 0) + if isinstance(prev_node, list): + example_inputs = [x.traced_result for x in prev_node] + elif isinstance(prev_node, tuple): + example_inputs = (x.traced_result for x in prev_node) # type: ignore[assignment] + else: + # currently some customer models do not have a traced_result in + # every node, so we have to guard for this case since we cannot + # quantize without an example input + # TODO(future PR): add a test case for this once we have an easy + # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489 + # for additional context + if hasattr(prev_node, "traced_result"): + example_inputs = (prev_node.traced_result,) # type: ignore[attr-defined, assignment] + else: + print( + "unable to get example input for node " + + f"{first_node.format_node()}, skipping" + ) + return + + # If there are no quantization configs for this subgraph, skip adding + # loggers. This reduces memory usage for models where not all layers are + # quantized. + # TODO(future): consider making this configurable + found_at_least_one_qconfig = False + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + if subgraph_candidate_idx == 0: + # fp32 baseline does not need a qconfig + continue + + # a. we have N shadows, so len(qconfig_mappings) is N + # b. we will have the fp32 layer + N shadows, so overall number of + # (original_op) + (*shadows) will be N+1 + # c. since `subgraph_candidate_idx` represents (b), we need + # to subtract 1 to query from (a) + node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + found_at_least_one_qconfig = True + break + if not found_at_least_one_qconfig: + print( + "unable to find at least one qconfig for node " + + f"{first_node.format_node()}, skipping" + ) + return + + fqn = _maybe_get_fqn(first_node, mt) + + # We want the results to contain the subgraphs in natural order, + # and the graph to also contain shadow wrappers and shadow loggers + # in natural order. + # If we just iterate in reverse, the graph will be in natural + # order but the eventual results will be in reverse order. + # So, we keep track of the last shadow logger we added and + # always insert after it. + last_added_shadow_node_list: list[Node | None] = [None] + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + create_one_transformed_and_logged_copy_of_subgraph( + mt, + subgraph_idx, + subgraph_candidate_idx, + first_node, + last_node, + fqn, + list_of_node_name_to_qconfig, + example_inputs, + last_added_shadow_node_list, + custom_prepare_fn, + custom_prepare_kwargs, + ) + + +def create_add_loggers_graph( + model: GraphModule, + subgraphs_dedup: dict[str, list[Node]], + qconfig_mapping: QConfigMapping, + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + r""" + Given a model, a model graph partition (currently a set of matched + subgraphs) and instructions how to transform each subgraph + (currently quantizing it according to qconfig_mapping), modifies + the model graph to create an alternate path through the original graph, + with each of the subgraphs quantized. This is useful to compare + propagation error of a transformation such as quantization. + + For example, given layer op0 and op1, there are four cases when handling op1: + 1. op0 and op1 quantized + 2. op0 and op1 unquantized + 3. op0 quantized, op1 unquantized + 4. op0 unquantized, op1 quantized + + Example input, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog op1_1 -> x2_1 ----> clog + + Example output, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog + + """ + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger + + def _get_subgraph_containing_node(node, subgraphs_dedup): + for subgraph in subgraphs_dedup.values(): + if node in subgraph: + return subgraph + return None + + # First, we need to create shadow branches, going from + # + # x0 -> op0 -> x1 -> ... + # + # + # to + # + # x0 -> op0_0 -> x1_0 -> log -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog + # + # Later, the outputs of each shadow will be rerouted to calculate + # propagation error. + + # Note: we cannot iterate over matched subgraphs because some nodes + # may not be matched. So, we iterate over nodes in the graph, and + # associate them to matched subgraphs if possible. + + nodes_to_skip = set() + # for each subgraph, save a mapping from first node of subgraph + # to first and last node of the shadow of this subgraph + orig_first_node_to_shadow_in_node = {} + orig_first_node_to_shadow_out_node = {} + # need to record original list because we will mutate the graph as we go + orig_nodes = list(model.graph.nodes) # type: ignore[union-attr, arg-type] + cur_subgraph_idx = 0 + for n in orig_nodes: + if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + insert_submodule_copy = False + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + nodes_to_skip.update(maybe_subgraph) + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + insert_submodule_copy = True + else: + first_node, last_node = n, n + + if insert_submodule_copy: + match_name = first_node.name + create_n_transformed_and_logged_copies_of_subgraph( + model, + cur_subgraph_idx, + match_name, + # pyrefly: ignore [bad-argument-type] + maybe_subgraph, + [qconfig_mapping], + [node_name_to_qconfig], + None, + None, # type: ignore[arg-type] + ) + # find the created shadow module and record it so we + # can find it easily in step 2 + expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1" + new_shadow_mod = None + for maybe_shadow_mod in model.graph.nodes: + if ( + maybe_shadow_mod.op == "call_module" + and maybe_shadow_mod.target == expected_shadow_target + ): + new_shadow_mod = maybe_shadow_mod + break + if new_shadow_mod is None: + raise AssertionError("Expected new_shadow_mod to be non-None") + orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod + orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod + + else: + # create a copy of the subgraph by only copying FX nodes + # but not copying any parameters, to minimize memory usage + subgraph_to_use = ( + maybe_subgraph if maybe_subgraph is not None else [first_node] + ) + + # add a regular logger after last_node + qconfig_str = "" + subgraph_candidate_idx = 0 + fqn = _maybe_get_fqn(first_node, model) + logger_mod_orig = _get_logger_for_subgraph( + model, + first_node, + last_node, + cur_subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputLogger, + fqn, + ) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + if hasattr(model, attr_name): + raise AssertionError( + f"Unexpected attribute '{attr_name}' found in {model}" + ) + setattr(model, attr_name, logger_mod_orig) + insertion_point = last_node + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(last_node,), kwargs={} + ) + insertion_point = logger + + # create a copy of the subgraph + cur_node_orig = first_node + cur_node_copy = None + first_node_copy = None + # pyrefly: ignore [bad-assignment] + while cur_node_orig in subgraph_to_use: + # TODO(future PR): make this support all possible args/kwargs + if cur_node_orig is first_node: + new_args = cur_node_orig.args + new_kwargs = cur_node_orig.kwargs + else: + first_arg_for_copy: Node | None = cur_node_copy + new_args = (first_arg_for_copy, *cur_node_orig.args[1:]) + new_kwargs = cur_node_orig.kwargs + # make a copy of cur_node_orig + with model.graph.inserting_after(insertion_point): + cur_node_copy = model.graph.create_node( + cur_node_orig.op, + cur_node_orig.target, + new_args, + new_kwargs, + # cur_node_orig.name, # TODO(future PR): set name explicitly + ) + if first_node_copy is None: + first_node_copy = cur_node_copy + # since now only linear subgraphs are supported, all nodes + # except the last one must have only one user + if cur_node_orig != last_node: + if len(cur_node_orig.users.keys()) != 1: + raise AssertionError( + f"Expected exactly 1, but got {len(cur_node_orig.users)}" + ) + cur_node_orig = next(iter(cur_node_orig.users.keys())) + if cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX): + raise AssertionError( + "cur_node_orig should not start with SHADOW_NODE_NAME_PREFIX" + ) + insertion_point = cur_node_copy + + # add a comparison logger after last_node's copy + subgraph_candidate_idx = 1 + logger_mod_orig = _get_logger_for_subgraph( + model, + first_node, + last_node, + cur_subgraph_idx, + subgraph_candidate_idx, + qconfig_str, + OutputComparisonLogger, + fqn, + ) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + if hasattr(model, attr_name): + raise AssertionError( + f"Unexpected attribute '{attr_name}' found in {model}" + ) + setattr(model, attr_name, logger_mod_orig) + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(cur_node_copy, last_node), kwargs={} + ) + + # save the final node so we can use it in step 2 + orig_first_node_to_shadow_in_node[first_node] = first_node_copy + orig_first_node_to_shadow_out_node[first_node] = cur_node_copy + + cur_subgraph_idx += 1 + + model.recompile() + + # Now, we go from + # + # x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ... + # \ \ \ + # -> op0_1 -> x1_1 -> clog -> op1_1 -> ... + # + # to + # + # x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ... + # + # sample values of key internal variables for the example above: + # + # orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1} + # orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1} + # + # note: for subgraphs with more than one node, in_node will be different + # compared to out_node + + nodes_to_skip = set() + for n in orig_nodes: + if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + nodes_to_skip.update(maybe_subgraph) + else: + first_node, last_node = n, n + + def maybe_remap_node_to_shadow(node): + """ + If unshadowed `node` has a shadow version, return that. If not, + return `node`. + """ + if not isinstance(node, Node): + # handle scalars + return node + + if node.op in ("placeholder", "get_attr"): + return node + + # Find the shadowed version of this arg from the previous + # subgraph. For this, we need to: + # 1. navigate to the first node of the previous subgraph + # 2. get the output of the shadow wrapper which has (1) as an input + + # For now, assume the arg is in matched subgraphs. In the + # future we may have to handle the case where this is not true. + prev_subgraph = _get_subgraph_containing_node(node, subgraphs_dedup) + if prev_subgraph is None: + prev_subgraph = [node] + prev_first_node = prev_subgraph[0] + prev_shadow_output = orig_first_node_to_shadow_out_node[prev_first_node] + return prev_shadow_output + + cur_shadow_input = orig_first_node_to_shadow_in_node[first_node] + if cur_shadow_input is None: + raise AssertionError("Expected cur_shadow_input to be non-None") + cur_shadow_input.args = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.args + ) + cur_shadow_input.kwargs = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.kwargs + ) + + model.recompile() + + +def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module): + # input: shadow wrapper module + # output if shadow wrapper module has a weighted op: + # (quantize_fn, (quantize_fn_args)) + # output if shadow wrapper module doesn't have a weighted op: + # None + + # For now, assume that the weight is the second input + # to the shadow module. If that changes, we can fix it later. + placeholders_seen = 0 + for shadow_n in shadow_wrapper.graph.nodes: # type: ignore[union-attr] + if shadow_n.op != "placeholder": + continue + + placeholders_seen += 1 + if placeholders_seen != 2: + continue + + # the subgraph looks like + # + # _input_scale_1 = self._input_scale_1 + # _input_zero_point_1 = self._input_zero_point_1 + # quantize_per_channel = torch.quantize_per_channel( + # w2_0, _input_scale_1, _input_zero_point_1, + # 0, torch.qint8) + # + # we have `w2_0`, and are navigating this subgraph + # to get `_input_scale_1` and `_input_zero_point_1` + + if len(shadow_n.users) != 1: + raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}") + quant_node = next(iter(shadow_n.users.keys())) + new_args: Any = None + if quant_node.target is torch.quantize_per_channel: + _weight, scale_node, zp_node, axis, dtype = quant_node.args + scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, axis, dtype) + else: + if quant_node.target != torch.quantize_per_tensor: + raise AssertionError( + f"Expected torch.quantize_per_tensor, but got {quant_node.target}" + ) + _weight, scale_node, zp_node, dtype = quant_node.args + scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, dtype) + return (quant_node.target, new_args) + + return None + + +def extract_weight_comparison(m: GraphModule) -> NSResultsType: + # example graph: + # + # w1 = self.w1 + # b1 = self.b1 + # linear = torch._C._nn.linear(x, w1, b1) + # shadow_0_0 = self.shadow_0_0(linear) + # shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1) + # shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear) + # + # algorithm: + # 1. for each call_function node matching our allowlist: + # 2. if corresponding shadow wrapper exists, extract the weight pair + # + # Note: this is not super robust, but that's ok because this is + # just for legacy customers who depend on the previous two-model version + # of this API. TBD if we need to make this robust. + # Note: modules are not supported, since existing customers only + # use functions. + + # TODO(future PR): move this to config + weighted_ops = { + torch.nn.functional.linear, + } + + results: NSResultsType = {"model": {NSSingleResultValuesType.WEIGHT.value: {}}} + + for n in m.graph.nodes: # type: ignore[union-attr] + if not (n.op == "call_function" and n.target in weighted_ops): + continue + + # Check if we have a corresponding shadow wrapper + # TODO(future PR, if needed): support kwargs + # TODO(future PR, if needed): support multiple shadow users + first_arg = n.args[0] + shadow_wrapper_node = None + for user in first_arg.users: + # TODO(before land): fix string match + if user.op == "call_module" and user.target.startswith("shadow_wrapper"): + shadow_wrapper_node = user + break + + if shadow_wrapper_node is None: + continue + + shadow_wrapper = getattr_from_fqn(m, shadow_wrapper_node.target) # type: ignore[arg-type] + weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper) + if weight_info is None: + continue + + # get weight + w_node = n.args[1] + w_obj = getattr_from_fqn(m, w_node.target).detach() + + # get a quantized version of weight + quant_fn, quant_fn_args_except_first = weight_info + new_args = (w_obj, *quant_fn_args_except_first) + w_obj_q = quant_fn(*new_args) + + # add a comparison + ref_node_name = n.name + prev_node_name = n.name + ref_node_type = get_target_type_str(n, m) + prev_node_type = ref_node_type + fqn = None + if hasattr(m, "_node_name_to_scope"): + fqn = m._node_name_to_scope[n.name][0] # type: ignore[index] + comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q) + result_fp32 = { + "res_type": NSSingleResultValuesType.WEIGHT.value, + "values": [w_obj], + "prev_node_name": prev_node_name, + "prev_node_target_type": prev_node_type, + "ref_node_name": ref_node_name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + "qconfig_str": "", + "comparisons": [comparison], + "comparison_fn_name": "sqnr", + } + result_q = { + "res_type": NSSingleResultValuesType.WEIGHT.value, + "values": [w_obj_q], + "prev_node_name": prev_node_name, + "prev_node_target_type": prev_node_type, + "ref_node_name": ref_node_name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + "qconfig_str": "", + "comparisons": [comparison], + "comparison_fn_name": "sqnr", + } + + # go from subgraph_n_1 to subgraph_n_0 + _1, _2, node_idx, _3 = shadow_wrapper_node.target.split("_") + name_fp32 = f"subgraph_{node_idx}_0" + name_q = f"subgraph_{node_idx}_1" + + results["model"][NSSingleResultValuesType.WEIGHT.value][name_fp32] = [ + result_fp32 + ] + results["model"][NSSingleResultValuesType.WEIGHT.value][name_q] = [result_q] + + return results + + +# TODO(future PR): redesign this to make it easier to consume outputs +def group_results_by_subgraph(results: NSResultsType) -> Any: + """ + Creates a comparison of results + + Input: + + { + 'model': { + 'node_output': { + 'subgraph_0_0': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [], ... + 'comparison_fn_name': '', + 'fqn': '...', + ], + 'subgraph_0_1': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + ], + ... + }, + }, + } + + Output: + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': None, + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + }, + } + + """ + subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict) + + # node_output or weight + key_to_use = next(iter(results["model"].keys())) + + for subgraph_name_with_idx, subgraph_candidate_results in results["model"][ + key_to_use + ].items(): + # convert from `subgraph_m_n` to `subgraph_m` and `n` + ( + subgraph_str, + subgraph_idx, + subgraph_candidate_idx, + ) = subgraph_name_with_idx.split("_") + subgraph_name = f"{subgraph_str}_{subgraph_idx}" + + subgraph_results = { + "ref_node_name": subgraph_candidate_results[0]["ref_node_name"], + "ref_node_target_type": subgraph_candidate_results[0][ + "ref_node_target_type" + ], + "fqn": subgraph_candidate_results[0]["fqn"], + "values": subgraph_candidate_results[0]["values"], + "qconfig_str": subgraph_candidate_results[0]["qconfig_str"], + "comparisons": subgraph_candidate_results[0]["comparisons"], + "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"], + } + + subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = ( + subgraph_results + ) + + return dict(subgraph_name_to_subgraph_results) + + +# TODO(future PR): redesign this to make it easier to consume outputs +def create_results_comparison( + results_grouped, +) -> Any: + """ + Input: + + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '', + 'comparisons': [], + 'comparison_fn_name': '', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], + 'comparison_fn_name': 'sqnr', + 'fqn': '...', + }, + }, + } + + Output: + { + 'subgraph_0': { + 'ref_node_name': '...', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': 'sqnr', + 'cmp_raw': [..., ...], + 'cmp_mean': ..., + }, + ..., + }, + }, + } + """ + + results_comparison = {} + + for subgraph_name, subgraph_results in results_grouped.items(): + candidates = {} + for subgraph_inner_name, subgraph_inner_result in subgraph_results.items(): + # skip comparing baseline to baseline + if subgraph_inner_name == "0": + continue + + # we expect the comparisons to be precalculated from + # calibration, so we just fetch them here + cmp_raw = subgraph_inner_result["comparisons"] + cmp_raw_tensor = torch.stack(cmp_raw) + + candidates[subgraph_inner_name] = { + "qconfig_str": subgraph_inner_result["qconfig_str"], + "comparison_fn_name": subgraph_inner_result["comparison_fn_name"], + "cmp_raw": cmp_raw_tensor, + "cmp_mean": torch.mean(cmp_raw_tensor), + } + + results_comparison[subgraph_name] = { + "ref_node_name": subgraph_results["0"]["ref_node_name"], + "ref_node_target_type": subgraph_results["0"]["ref_node_target_type"], + "fqn": subgraph_results["0"]["fqn"], + "candidates": candidates, + } + + return results_comparison + + +# TODO(future PR): redesign this to make it easier to consume outputs +def print_n_shadows_summary( + results_comparison, +) -> None: + """ + Input: + + { + 'subgraph_0': { + 'ref_node_name': 'linear1', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': ..., + 'cmp_raw': [45.0, 55.0], + 'cmp_mean': 50.0, + }, + ..., + }, + }, + } + + Prints: + + node_name | node_type | fqn | 0 | 1 | ... + linear1 | ... | ... | 45.0 | 50.0 | ... + """ + + try: + from tabulate import tabulate + except ImportError: + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) + return + + results = [] + for subgraph_data in results_comparison.values(): + mean_all_candidates = [ + candidate["cmp_mean"] + for candidate_name, candidate in subgraph_data["candidates"].items() + ] + + data_row = [ + subgraph_data["ref_node_name"], + subgraph_data["ref_node_target_type"], + subgraph_data["fqn"], + *mean_all_candidates, + ] + results.append(data_row) + + max_candidate_idx_len = -1 + for data_row in results: + max_candidate_idx_len = max(max_candidate_idx_len, len(data_row[1])) + candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)] + + headers = ["node_name", "node_type", "fqn", *candidate_idx_headers] + print(tabulate(results, headers=headers)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py new file mode 100644 index 0000000000000000000000000000000000000000..134fd485130e0069ab992197ea6e176e1e1e216b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py @@ -0,0 +1,66 @@ +import enum +from collections.abc import Callable +from typing import Any, NamedTuple, Union + +from torch.fx.graph import Node + + +class NSSingleResultValuesType(str, enum.Enum): + WEIGHT = "weight" + NODE_OUTPUT = "node_output" + NODE_INPUT = "node_input" + + +class NSSubgraph(NamedTuple): + start_node: Node + end_node: Node + base_op_node: Node + + +# TODO(future PR): see if we can use typing_extensions's TypedDict instead +# to properly type the various keys +# { +# # one of NSSingleResultValuesType +# 'type': 'weight', +# # the values of type specified above +# 'values': [torch.tensor(...), ...], +# # name of the node directly before the logger +# 'prev_node_name': 'linear1', +# # type of the underlying function or module +# 'prev_node_target_type': torch.nn.functional.linear # or torch.nn.Linear, etc +# # name of the node responsible for adding this logger +# # Note: this may differ from prev_node_name if we are logging inputs +# 'ref_node_name': 'linear1', +# # index of this node within the arg of the input/output node +# # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 +# 'index_within_arg': 0, +# # index of this node within the args of the input/output node +# # for example, in add(x1, x2), x2 would have index_of_arg == 1 +# 'index_of_arg': 0, +# # precomputed comparisons of logger values to reference values +# 'comparisons': [torch.tensor(...), ...] +# # name of function used for precomputed comparisons +# 'comparison_fn_name': 'sqnr', +# # string representation of qconfig responsible for creating this logger +# 'qconfig_str': 'QConfig(...)', +# } +NSSingleResultType = dict[str, Any] + +# { +# 'layer_name_1': { # subgraph name +# 'node_output': { # results type (node_output, node_input, weight) +# 'model_name_a': # model name +# [NSSingleResultType, ...], # results, ordered by index_within_arg +# 'model_name_b': +# [NSSingleResultType, ...], +# }, +# }, +# } +# +NSResultsType = dict[str, dict[str, dict[str, list[NSSingleResultType]]]] + +# Defines the underlying target type of a node, for example: +# `F.conv1d` for a `call_function` conv node +# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module +# `'sigmoid'` for a `call_method` node calling `x.sigmoid()` +NSNodeTargetType = Union[Callable, str] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d10fdd39da9080144d3f6ef577d3ca5aca313538 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py @@ -0,0 +1,214 @@ +from collections.abc import Callable +from typing import Any, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.backend_config import get_native_backend_config +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSNodeTargetType + + +toq = torch.ops.quantized + + +def get_type_a_related_to_b( + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]], +) -> set[tuple[NSNodeTargetType, NSNodeTargetType]]: + # TODO(future PR): allow customizations + # TODO(future PR): reuse existing quantization mappings + # TODO(future PR): add the rest of modules and ops here + type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]] = set() + + for s in base_name_to_sets_of_related_ops.values(): + s_list = list(s) + # add every bidirectional pair + for idx_0 in range(len(s_list)): + for idx_1 in range(idx_0, len(s_list)): + type_a_related_to_b.add((s_list[idx_0], s_list[idx_1])) + type_a_related_to_b.add((s_list[idx_1], s_list[idx_0])) + + return type_a_related_to_b + + +NSFusionElType = Union[ + Callable, # call_function or call_module type, example: F.linear or nn.Conv2d + str, # call_method name, example: "dequantize" + tuple[ + str, Any + ], # call_method name and first argument, example: ("to", torch.float16) +] +NSFusionType = Union[ + tuple[NSFusionElType, NSFusionElType], + tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType], +] + + +def get_reversed_fusions() -> list[tuple[NSFusionType, int]]: + """ + Set of potential fusions, in reverse order. The order is reversed + to match how fusion patterns are defined in quantization code. + + Fusion format: + ((fusion_op_0, fusion_op_1), base_op_idx) + + Where base_op_idx is the idx of the op we should use to match other related + ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx + of 0 represents the first op in regular (non-reverse) order, 1 represents the + second op, etc. + """ + results: list[tuple[NSFusionType, int]] = [] + + # Possible syntaxes: + # * single op: torch.nn.Conv2d + # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d) + # For fusions, we only care about patterns composed of multiple ops. + # TODO(future PR): allow customizations from default patterns. + all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) + + default_base_op_idx = 0 + for quant_pattern in all_quant_patterns: + # TODO: this is a temporary hack to flatten the patterns from quantization so + # that it works with the ns matcher function, maybe we should use `_is_match` + # in torch.ao.quantization.fx.match_utils to match the patterns + if ( + isinstance(quant_pattern, tuple) + and len(quant_pattern) == 2 + and isinstance(quant_pattern[1], tuple) + and len(quant_pattern[1]) == 2 + ): + # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1]) + + # Only patterns of multiple ops are fusions, ignore + # patterns which contain a single ops (they get matched + # without caring about fusions). + if isinstance(quant_pattern, tuple): + results.append((quant_pattern, default_base_op_idx)) # type: ignore[arg-type] + + # For each pattern, add additional patterns with observers and + # fake quants at the end. + # TODO(future PR): if needed, implement matching for a node + # having multiple output observers. + for cls in (ObserverBase, FakeQuantizeBase): + if isinstance(quant_pattern, tuple): + new_pattern = (cls, *quant_pattern) + else: + new_pattern = (cls, quant_pattern) + results.append((new_pattern, default_base_op_idx)) # type: ignore[arg-type] + + # After this point, results contains values such as + # [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...] + + # Patterns for matching fp16 emulation are not specified in the quantization + # fusion mappings. For now, define them here. + fp16_em_base_op_idx = 1 + patterns_to_add = [ + # linear-relu fp16 emulation: + # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16 + ( + (("to", torch.float16), F.relu, F.linear, "dequantize"), + fp16_em_base_op_idx, + ), + # Conv-BN fusion (this happens outside of quantization patterns, + # which is why it is defined separately here). + ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ] + for p in patterns_to_add: + results.append(p) # type: ignore[arg-type] + results.append(((ObserverBase, *p[0]), p[1])) # type: ignore[arg-type] + results.append(((FakeQuantizeBase, *p[0]), p[1])) # type: ignore[arg-type] + + return results + + +def end_node_matches_reversed_fusion( + end_node: Node, + reversed_fusion: NSFusionType, + gm: GraphModule, + seen_nodes: set[Node], +) -> bool: + """ + Returns true if a pattern ending with `end_node` matches + the fusion pattern. + """ + cur_node = end_node + for fusion_idx in range(len(reversed_fusion)): + # each node can only belong to one matched pattern + if cur_node in seen_nodes: + return False + + cur_fusion_el = reversed_fusion[fusion_idx] + + if cur_node.op == "call_function": + fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and ( + not isinstance(cur_fusion_el, type) + ) + if fusion_el_is_fun: + if cur_node.target != cur_fusion_el: + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == "call_module": + fusion_el_is_mod = isinstance(cur_fusion_el, type) + if fusion_el_is_mod: + if not isinstance(cur_node.target, str): + raise AssertionError(f"Expected str, got {type(cur_node.target)}") + target_mod = getattr_from_fqn(gm, cur_node.target) + if not isinstance(cur_fusion_el, type): + return False + if not isinstance(target_mod, cur_fusion_el): + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == "call_method": + fusion_el_is_meth_with_second_arg = ( + isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2 + ) + fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str) + if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg: + if fusion_el_is_meth_without_args: + if cur_node.target != cur_fusion_el: + return False + else: + if not isinstance(cur_fusion_el, tuple): + raise AssertionError( + f"Expected tuple, got {type(cur_fusion_el)}" + ) + if cur_node.target != cur_fusion_el[0]: + return False + elif len(cur_node.args) < 2: + return False + elif cur_node.args[1] != cur_fusion_el[1]: + return False + + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + else: + return False + + return True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..d36914b46929d7eb8311097cd6b5d0fdc0c82f12 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.ao.quantization.qconfig import QConfigAny + +__all__ = ["QConfigMultiMapping"] + +_QCONFIG_STYLE_TO_METHOD: dict[str, str] = { + "global_qconfig": "set_global", + "object_type_qconfigs": "set_object_type", + "module_name_regex_qconfigs": "set_module_name_regex", + "module_name_qconfigs": "set_module_name", + "module_name_object_type_order_qconfigs": "set_module_name_object_type_order", +} + + +def _remove_duplicates_and_none(qconfig_list: list[QConfigAny]) -> None: + to_remove = [] + for index, cur_qconfig in enumerate(qconfig_list): + if cur_qconfig is None: + to_remove.append(index) + break + for checked_qconfig in qconfig_list[:index]: + if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig): + to_remove.append(index) + break + for index in to_remove[::-1]: + qconfig_list.pop(index) + + +class QConfigMultiMapping: + """ + This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s + so that multiple QConfigs can be specified for each QConfig matching style. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfigs + + ``set_object_type`` : sets the QConfigs for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string + + ``set_module_name`` : sets the QConfigs for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a + single QConfig. + + Example usage:: + + qconfig_mapping = QConfigMultiMapping() + .set_global([qconfig1, qconfig2]) + .set_object_type(torch.nn.Linear, [qconfig2, qconfig3]) + .set_object_type(torch.nn.ReLU, [qconfig1]) + .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2]) + .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3]) + .set_module_name("module1", [None]) + .set_module_name("module2", [qconfig2]) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3]) + + """ + + def __init__(self) -> None: + # initialize this with 1 QConfigMapping to avoid corner cases + self.qconfig_mappings_list: list[QConfigMapping] = [QConfigMapping()] + + def _handle_list_size_mismatch( + self, qconfig_list: list[QConfigAny], style: str + ) -> None: + # this method handles cases where the size of qconfig_list does not match + # the size of qconfig_mappings_list. + # Issue: Consider a user inserting global_qconfig A and B first, then inserting + # qconfig C as an object_type_qconfig for conv ops. If we internally store + # 1 QConfigMapping with A and C and another with just B, then the + # second QConfigMapping will match B to conv ops (which is not wanted), since B is global. + + # we avoid this by maintaining the invariant that if any QConfigMapping + # has a qconfig style+key with a qconfig in it, all QConfigMappings must + # have either a qconfig or None for that same style+key. In the above + # example, a None qconfig would prevent the unwanted match in the + # second QConfigMapping + + if len(qconfig_list) > len(self.qconfig_mappings_list): + # Case: we have more qconfigs (in qconfig_list) than QConfigMappings + + # Add new QConfigMappings (initialized so we maintain the `invariant`) + + new_qconfig_mapping = QConfigMapping() + # searches other QConfigMappings for qconfig style+keys + # that need to be inserted as `None` into the new QConfigMapping + for qconfig_mapping in self.qconfig_mappings_list: + # global_qconfig has None by default + for check_style in _QCONFIG_STYLE_ORDER[1:]: + qconfigs_dict = getattr(qconfig_mapping, check_style) + target_qconfigs_dict = getattr(new_qconfig_mapping, check_style) + for key in qconfigs_dict: + target_qconfigs_dict[key] = None + break + + # insert copies of this new QConfigMapping until all entries + # in qconfig_list can fit among the QConfigMappings + while len(qconfig_list) > len(self.qconfig_mappings_list): + self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) + else: + # Case: we have fewer qconfigs in qconfig_list than QConfigMappings + + # pad qconfig_list with `None` until length is same + while len(qconfig_list) < len(self.qconfig_mappings_list): + qconfig_list.append(None) + + # this function applies the insertion method across each QConfigMapping + def _insert_qconfig_list( + self, + style: str, + args: list[str | int | Callable], + qconfig_list: list[QConfigAny], + ) -> None: + # we remove duplicates and None to make the ordering of qconfigs + # deterministic upon insertion. + _remove_duplicates_and_none(qconfig_list) + + self._handle_list_size_mismatch(qconfig_list, style) + method_name = _QCONFIG_STYLE_TO_METHOD[style] + for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list): + # uses QConfigMapping set method to insert qconfig + set_method = getattr(qconfig_mapping, method_name) + set_method(*args, qconfig) + + def set_global(self, global_qconfig_list: list[QConfigAny]) -> QConfigMultiMapping: + """ + Set global QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info + """ + self._insert_qconfig_list("global_qconfig", [], global_qconfig_list) + return self + + def set_object_type( + self, object_type: Callable | str, qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set object type QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info + """ + self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list) + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name_regex QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info + """ + self._insert_qconfig_list( + "module_name_regex_qconfigs", [module_name_regex], qconfig_list + ) + return self + + def set_module_name( + self, module_name: str, qconfig_list: list[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info + """ + self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list) + return self + + def set_module_name_object_type_order( + self, + module_name: str, + object_type: Callable, + index: int, + qconfig_list: list[QConfigAny], + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info + """ + self._insert_qconfig_list( + "module_name_object_type_order_qconfigs", + [module_name, object_type, index], + qconfig_list, + ) + return self + + def __repr__(self): + return ( + self.__class__.__name__ + + " [" + + "".join( + f"\n{qconfig_mapping.__repr__()}," + for qconfig_mapping in self.qconfig_mappings_list + ) + + "\n]" + ) + + @classmethod + def from_list_qconfig_mapping( + cls, qconfig_mapping_list: list[QConfigMapping] + ) -> QConfigMultiMapping: + """ + Creates a QConfigMultiMapping from a list of QConfigMappings + """ + new_qconfig_multi_mapping = cls() + + new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy( + qconfig_mapping_list + ) + + # we need to avoid the issue described in _handle_list_size_mismatch, + # so we reinsert all the qconfigs using the QConfigMultiMapping + # set methods + + # go through all qconfig styles + # note: global can be ignored since it is None by default + for style in _QCONFIG_STYLE_ORDER[1:]: + # gather all key+qconfigs for current style + # into qconfig_dict_list + qconfig_dict_list: dict[Any, list[QConfigAny]] = {} + for qconfig_mapping in qconfig_mapping_list: + qconfig_dict = getattr(qconfig_mapping, style) + for key, qconfig in qconfig_dict.items(): + if key not in qconfig_dict_list: + qconfig_dict_list[key] = [] + qconfig_dict_list[key].append(qconfig) + + # reinsert all gathered key+qconfigs + set_method_name = _QCONFIG_STYLE_TO_METHOD[style] + set_method = getattr(new_qconfig_multi_mapping, set_method_name) + for key, qconfig_list in qconfig_dict_list.items(): + if isinstance(key, tuple): + set_method(*key, qconfig_list) + else: + set_method(key, qconfig_list) + + return new_qconfig_multi_mapping diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..93e72ae2fd4b64ae1b529e06bb8af988a747f690 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py @@ -0,0 +1,579 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import enum +import operator +from collections.abc import Callable + +import torch +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.quantized as nnq +import torch.nn as nn +from torch.ao.quantization import FakeQuantizeBase, ObserverBase +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.utils import getattr_from_fqn +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSNodeTargetType, NSResultsType + + +toq = torch.ops.quantized + + +# TODO(future PR): consider deleting this enum and using the torch types +# directly. This might be tricky because it is not a one to one mapping. +class NodeInputOrOutputType(enum.Enum): + FP32 = enum.auto() # torch.float + INT8 = enum.auto() # torch.qint8 or torch.quint8 + FP16 = enum.auto() # torch.float16 + UNKNOWN = enum.auto() # we cannot determine input/output dtype + # TODO(future PR): while these functions can support multiple dtypes, + # for the purposes of numerical debugging we want to get the actual + # dtype used in the model. We will likely need some kind of dtype + # propagation to estimate this. + FP32_OR_INT8 = enum.auto() # either torch.float or torch.quint8 or torch.qint8 + # TODO(future PRs): dynamic quant, fake quant, etc + + +def get_node_first_input_and_output_type( + node: Node, + gm: GraphModule, + logger_cls: Callable, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> tuple[NodeInputOrOutputType, NodeInputOrOutputType]: + # TODO(future PR): clean this up + FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"] + FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"] + FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"] + FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"] + MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"] + MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"] + MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"] + METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"] + + if node.op == "call_function": + if node.target in FUNS_IO_TYPE_FP32: + return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32) + if node.target in FUNS_IO_TYPE_FP16: + return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16) + elif node.target in FUNS_IO_TYPE_INT8: + return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8) + elif node.target in FUNS_IO_TYPE_FP32_OR_INT8: + first_arg = get_normalized_nth_input(node, gm, 0) + if not isinstance(first_arg, Node): + raise AssertionError(f"Expected Node, got {type(first_arg)}") + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + elif node.op == "call_module": + if node.op != "call_module": + raise AssertionError(f"Expected call_module, got '{node.op}'") + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, but got {type(node.target)}") + + mod = getattr_from_fqn(gm, node.target) + is_known_fp32_or_int8_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 + ) + if ( + isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)) # type: ignore[arg-type] + or is_known_fp32_or_int8_input_module + ): + # A logger or observer's input and output type is the output + # type of the preceding node. + first_arg = get_normalized_nth_input(node, gm, 0) + if not isinstance(first_arg, Node): + raise AssertionError(f"Expected Node, got {type(first_arg)}") + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + is_known_fp32_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32 + ) + is_known_int8_input_module = any( + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_INT8 + ) + if is_known_fp32_input_module: + return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32) + elif is_known_int8_input_module: + return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + elif node.op == "call_method": + if node.target == "dequantize": + # Dequantize is a special node because it allows multiple input types. + # So, we look up the output type of the previous node and return that + # as the input type of this node instance. + prev_node = get_normalized_nth_input(node, gm, 0) + if not isinstance(prev_node, Node): + raise AssertionError(f"Expected Node, got {type(prev_node)}") + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + prev_node, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, NodeInputOrOutputType.FP32) + + elif node.target == "to": + # to is a special node because it allows multiple input types. + # So, we look up the output type of the previous node and return that + # as the input type of this node instance. We also look up the target + # of to and return the correct output type. + prev_node = get_normalized_nth_input(node, gm, 0) + if not isinstance(prev_node, Node): + raise AssertionError(f"Expected Node, got {type(prev_node)}") + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + prev_node, gm, logger_cls, node_type_to_io_type_map + ) + + cur_node_dtype_target = get_normalized_nth_input(node, gm, 1) + if cur_node_dtype_target is not torch.float16: + raise AssertionError( + f"{cur_node_dtype_target} handling needs to be added" + ) + + return (prev_node_output_type, NodeInputOrOutputType.FP16) + + elif node.target in METHS_IO_TYPE_FP32_OR_INT8: + first_arg = get_normalized_nth_input(node, gm, 0) + if not isinstance(first_arg, Node): + raise AssertionError(f"Expected Node, got {type(first_arg)}") + ( + _prev_node_input_type, + prev_node_output_type, + ) = get_node_first_input_and_output_type( + first_arg, gm, logger_cls, node_type_to_io_type_map + ) + return (prev_node_output_type, prev_node_output_type) + + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + else: + return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN) + + +def get_node_input_qparams( + node: Node, + gm: GraphModule, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]], +) -> tuple[torch.Tensor | float, torch.Tensor | int] | None: + """ + Returns the qparams (scale, zero_point) of the first input to `node`, + if they can be inferred from the graph. + """ + prev_node = get_normalized_nth_input(node, gm, 0) + + if not isinstance(prev_node, Node): + return None + + MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"] + + def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx): + scale_node = get_normalized_nth_input(node, gm, scale_arg_idx) + zp_node = get_normalized_nth_input(node, gm, zp_arg_idx) + if not isinstance(scale_node, Node): + raise AssertionError(f"Expected Node, got {type(scale_node)}") + if not isinstance(scale_node.target, str): + raise AssertionError(f"Expected str, got {type(scale_node.target)}") + if not isinstance(zp_node, Node): + raise AssertionError(f"Expected Node, got {type(zp_node)}") + if not isinstance(zp_node.target, str): + raise AssertionError(f"Expected str, got {type(zp_node.target)}") + scale_obj = getattr_from_fqn(gm, scale_node.target) + zp_obj = getattr_from_fqn(gm, zp_node.target) + return (scale_obj, zp_obj) + + if prev_node.op == "call_function": + # quantize - read the args directly + if prev_node.target is torch.quantize_per_tensor: + return _get_scale_zp_from_function_args(prev_node, gm, 1, 2) + elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu): + return _get_scale_zp_from_function_args(prev_node, gm, 2, 3) + + return None + # TODO(future PR): handle more functionals + # TODO(future PR): handle functional ops which inherit qparams from input + + elif prev_node.op == "call_module": + # get type of the module + if not isinstance(prev_node.target, str): + raise AssertionError(f"Expected str, got {type(prev_node.target)}") + module_obj = getattr_from_fqn(gm, prev_node.target) + if isinstance( + module_obj, + ( + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nniq.ConvReLU2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.GroupNorm, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.ReLU6, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + ), + ): + return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value] + + is_known_fp32_or_int8_input_module = any( + isinstance(module_obj, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 + ) + if is_known_fp32_or_int8_input_module: + return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map) + + return None + + +def return_first_non_observer_node( + node: Node, + gm: GraphModule, +) -> Node: + """ + If node is not an observer, returns it. If node is an observer, + navigates up the graph and returns the first parent which is not an + observer. For example, + + graph: (node_non_obs), node = node_non_obs : returns node_non_obs + graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs + graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs + """ + if node.op == "call_module": + node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] + if _is_activation_post_process(node_obj): + if len(node.args) != 1: + raise AssertionError( + f"Expected node.args to have length 1, got {len(node.args)}" + ) + if not isinstance(node.args[0], Node): + raise AssertionError(f"Expected Node, got {type(node.args[0])}") + node = node.args[0] + # code duplication intended, not worth refactoring + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + node_obj = getattr_from_fqn(gm, node.target) + if _is_activation_post_process(node_obj): + if len(node.args) != 1: + raise AssertionError( + f"Expected node.args to have length 1, got {len(node.args)}" + ) + if not isinstance(node.args[0], Node): + raise AssertionError(f"Expected Node, got {type(node.args[0])}") + node = node.args[0] + return node + + +def get_number_of_non_param_args( + node: Node, + gm: GraphModule, +) -> int: + """ + Assumes that all non-param args occur first. Returns the number of + non-param args expected for a node. For example, for + + F.linear(x, weight, bias) + + Returns 1, because x is a non-param arg and weight and bias are params. + For + + lstm_mod(x, hid) + + Returns 2, because both x and hid are non-param args. + """ + if node.op == "call_module": + node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] + if isinstance(node_obj, nn.LSTM): + return 2 + + # default is 1 + return 1 + + +def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]: + """ + Returns the indices of args of the node which we should attach + loggers to, if input logging is enabled. + + For example, + * for (x + y), returns [0, 1] + * for (1 + y), returns [1] + * for (x + 1), returns [0] + * for (linear(x, w, b)) returns [0] + * by default, returns [0] + """ + if len(node.args) == 0: + return [] + if node.op == "call_function" and ( + # TODO(future PR): use relationship map instead of hardcoding + node.target in (torch.add, torch.ops.quantized.add, operator.add) + or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul) + ): + result = [i for i in range(2) if type(node.args[i]) is Node] + return result + return [0] + + +def get_target_type_str(node: Node, gm: GraphModule) -> str: + """ + Returns a string representation of the type of the function or module + pointed to by this node, or '' for other node types. + """ + target_type = "" + if node.op in ("call_function", "call_method"): + target_type = torch.typename(node.target) + elif node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + target_mod = getattr_from_fqn(gm, node.target) + target_type = torch.typename(target_mod) + return target_type + + +def rekey_logger_info_on_node_name_of_model( + results: NSResultsType, + model_name: str, +) -> NSResultsType: + """ + Rekeys the layer name of a results dictionary to use node names + from `model_name`. + + For example, transforms + + {'base_op_1_0': {'node_output': {'model_a': + [{'ref_node_name': 'linear1', ...}]}}} + + into + + {'linear1': {'node_output': {'model_a': + [{'ref_node_name': 'linear1', ...}]}}} + + Note: we cannot use these node names directly because they are not + guaranteed to be consistent across models. This is why we extract + the results first and rekey afterwards. + """ + new_results = {} + for old_layer_name, result_type_to_results in results.items(): + new_layer_name = None + for model_name_to_results in result_type_to_results.values(): + for cur_model_name, list_of_results in model_name_to_results.items(): + if cur_model_name == model_name: + if len(list_of_results) == 0: + raise AssertionError("Expected list_of_results to be not empty") + new_layer_name = list_of_results[0]["ref_node_name"] + else: + continue + if new_layer_name is not None: + new_results[new_layer_name] = result_type_to_results + else: + new_results[old_layer_name] = result_type_to_results + return new_results + + +def maybe_add_missing_fqns(results: NSResultsType) -> None: + """ + If `fqn` entries are filled in for one of the models in `results`, copies + them over to any models which do not have them filled out. + + A common use case benefitting from this is comparing a model prepared by + quantization to a quantized model. In this case, the model prepared by + quantization would have `fqn` entries, and the quantized model would not. + """ + + # Check in the first result to find any model with fqn entries defined. + model_name_with_fqns = None + for result_type_to_results in results.values(): + for model_name_to_results in result_type_to_results.values(): + for model_name, model_results in model_name_to_results.items(): + if len(model_results) > 0: + if model_results[0]["fqn"] is not None: + model_name_with_fqns = model_name + break + break + break + + if model_name_with_fqns: + for result_type_to_results in results.values(): + for model_name_to_results in result_type_to_results.values(): + ref_model_results = model_name_to_results[model_name_with_fqns] + for model_name, model_results in model_name_to_results.items(): + if model_name == model_name_with_fqns: + continue + + for i in range(len(model_results)): + fqn = ref_model_results[i]["fqn"] + model_results[i]["fqn"] = fqn + + +def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f): + def inner(*args, **kwargs): + a0, a1, *a_other = args + + if (isinstance(a0, tuple) and isinstance(a1, tuple)) or ( + isinstance(a0, list) and isinstance(a1, list) + ): + results = [] + for el0, el1 in zip(a0, a1): + new_args = (el0, el1, *a_other) + results.append(inner(*new_args, **kwargs)) + return results + + elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor): + if a0.is_quantized: + a0 = a0.dequantize() + if a1.is_quantized: + a1 = a1.dequantize() + + # for the purposes of this util, only handle floats + if a0.dtype != torch.float or a1.dtype != torch.float: + return None + + new_args = (a0, a1, *a_other) + return f(*new_args, **kwargs) + + return inner + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the SQNR between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + Ps = torch.norm(x) + Pn = torch.norm(x - y) + return 20 * torch.log10(Ps / Pn) + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the normalized L2 error between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + # pyrefly: ignore [unsupported-operation] + return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum()) + + +@maybe_dequantize_first_two_tensor_args_and_handle_tuples +def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the cosine similarity between `x` and `y`. + + Args: + x: Tensor or tuple of tensors + y: Tensor or tuple of tensors + + Return: + float or tuple of floats + """ + # For convolutions, the shape of the quantized weight has one additional + # dimension compared to the shape of the fp32 weight. Match the shapes + # to enable cosine similarity comparison. + x = x.reshape(1, -1) + y = y.reshape(1, -1) + return torch.nn.functional.cosine_similarity(x, y) + + +def op_type_supports_shadowing(node: Node) -> bool: + if node.op == "call_function": + if node.target in ( + torch.add, + torch.mul, + operator.add, + operator.mul, + torch.cat, + torch.stack, + ): + # shadowing for ops with multiple tensor inputs is not implemented yet + return False + return True + + +def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node: + """ + Given a node, gets the n'th input to that node, normalizing + args and kwargs to the best of its ability. + """ + try: + norm_args_and_kwargs = node.normalized_arguments( + gm, normalize_to_only_use_kwargs=True + ) + if norm_args_and_kwargs is not None: + norm_args, norm_kwargs = norm_args_and_kwargs + if len(norm_args) + len(norm_kwargs) <= idx: + raise AssertionError( + f"Index {idx} out of range: total = {len(norm_args) + len(norm_kwargs)}" + ) + if idx < len(norm_args): + return norm_args[idx] + else: + # note: in Python 3.7+ dicts are ordered + return list(norm_kwargs.values())[idx] + else: + if len(node.args) + len(node.kwargs) <= idx: + raise AssertionError( + f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}" + ) + if idx < len(node.args): + return node.args[idx] # type: ignore[return-value] + else: + kwargs_idx = idx + len(node.args) + return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value] + except RuntimeError: + # this RuntimeError happens when node argument normalization + # requires typehints to proceed, such as for torch.add where + # either the first, second or both arguments could be tensors + if len(node.args) + len(node.kwargs) <= idx: + raise AssertionError( + f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}" + ) from None + if idx < len(node.args): + return node.args[idx] # type: ignore[return-value] + else: + kwargs_idx = idx + len(node.args) + return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6bff44215e46174856918883f35aac92b4491c25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py @@ -0,0 +1,302 @@ +from collections.abc import Callable + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.nn as nn +import torch.nn.functional as F +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .ns_types import NSSingleResultType, NSSingleResultValuesType +from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node + + +toq = torch.ops.quantized + + +def mod_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod.weight.detach() # type: ignore[operator] + + +def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod[0].weight.detach() # type: ignore[index] + + +def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_lstm_weight(mod: nn.Module) -> list[torch.Tensor]: + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type] + if "weight_ih_l" in param_name or "weight_hh_l" in param_name: + param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr] + res.append(param_value) + return res + + +def get_qlstm_weight(mod: nn.Module) -> list[torch.Tensor]: + res = [] + for weight_value in mod._all_weight_values: # type: ignore[union-attr] + res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]) + res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]) + return res + + +def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor: + if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + return mod.weight.detach() + elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)): + return mod[0].weight.detach() # type: ignore[operator] + else: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor: + if isinstance(mod, nn.Linear): + return mod.weight.detach() + elif isinstance(mod, nni.LinearReLU): + return mod[0].weight.detach() # type: ignore[operator] + else: + return mod._weight_bias()[0] # type: ignore[operator] + + +def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]: + # TODO(future PR): make more generic, handle everything + if isinstance(mod, nn.LSTM): + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): + if "weight_ih_l" in param_name or "weight_hh_l" in param_name: + param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr] + res.append(param_value) + return res + else: + if not isinstance(mod, nnqd.LSTM): + raise AssertionError(f"type {type(mod)} not handled yet") + res = [] + for weight_value in mod._all_weight_values: + res.append( + weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0] # type: ignore[index] + ) + res.append( + weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0] # type: ignore[index] + ) + return res + + +def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + weight_arg_node = node.args[1] + if not isinstance(weight_arg_node, Node): + raise AssertionError(f"Expected Node, got {type(weight_arg_node)}") + weight_node = return_first_non_observer_node(weight_arg_node, gm) + if not isinstance(weight_node, Node): + raise AssertionError(f"Expected Node, got {type(weight_node)}") + if weight_node.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {weight_node.op}") + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + + +def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # qconv state is arg 1 + qconv_state_node = node.args[1] + if not isinstance(qconv_state_node, Node): + raise AssertionError(f"Expected Node, got {type(qconv_state_node)}") + if qconv_state_node.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {qconv_state_node.op}") + qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type] + return qconv_state_obj.weight() + + +def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + # supported patterns: + # weight -> obs -> linear + # weight -> to(torch.float16) -> dequantize -> linear + linear_second_arg = node.args[1] + if not isinstance(linear_second_arg, Node): + raise AssertionError(f"Expected Node, got {type(linear_second_arg)}") + + if linear_second_arg.op == "call_module": + # weight -> obs -> linear + weight_arg_node = node.args[1] + if not isinstance(weight_arg_node, Node): + raise AssertionError(f"Expected Node, got {type(weight_arg_node)}") + weight_node = weight_arg_node.args[0] + if not isinstance(weight_node, Node): + raise AssertionError(f"Expected Node, got {type(weight_node)}") + if weight_node.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {weight_node.op}") + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + elif linear_second_arg.op == "call_method": + # weight -> to(torch.float16) -> dequantize -> linear + if linear_second_arg.op != "call_method": + raise AssertionError(f"Expected call_method, got {linear_second_arg.op}") + dequant_node = node.args[1] + if not isinstance(dequant_node, Node): + raise AssertionError(f"Expected Node, got {type(dequant_node)}") + to_fp16_node = dequant_node.args[0] + if not isinstance(to_fp16_node, Node): + raise AssertionError(f"Expected Node, got {type(to_fp16_node)}") + # extract the dtype, so we can cast to it before returning + target_dtype = to_fp16_node.args[1] + weight_node = to_fp16_node.args[0] + if not isinstance(weight_node, Node): + raise AssertionError(f"Expected Node, got {type(weight_node)}") + if weight_node.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {weight_node.op}") + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + # return the weight with fp16 cast + return weight.detach().to(target_dtype) + else: + if linear_second_arg.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {linear_second_arg.op}") + weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type] + return weight.detach() + + +def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # packed weight is arg 1 + packed_weight_node = node.args[1] + if not isinstance(packed_weight_node, Node): + raise AssertionError(f"Expected Node, got {type(packed_weight_node)}") + if packed_weight_node.op != "get_attr": + raise AssertionError(f"Expected get_attr, got {packed_weight_node.op}") + packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type] + # TODO(future PR): why does packed_weight.unpack() not work? + (weight, _bias), _name = packed_weight.__getstate__() + return weight + + +def get_op_to_type_to_weight_extraction_fn() -> dict[str, dict[Callable, Callable]]: + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] = { + "call_module": { + # Conv1d + nn.Conv1d: mod_weight_detach, + nni.ConvReLU1d: mod_0_weight_detach, + nnq.Conv1d: mod_weight_bias_0, + nnqat.Conv1d: mod_weight_detach, + nniqat.ConvBn1d: mod_weight_detach, + nniqat.ConvBnReLU1d: mod_weight_detach, + nniqat.ConvReLU1d: mod_weight_detach, + nniq.ConvReLU1d: mod_weight_bias_0, + # Conv2d + nn.Conv2d: mod_weight_detach, + nni.ConvReLU2d: mod_0_weight_detach, + nnq.Conv2d: mod_weight_bias_0, + nnqat.Conv2d: mod_weight_detach, + nniqat.ConvBn2d: mod_weight_detach, + nniqat.ConvBnReLU2d: mod_weight_detach, + nniqat.ConvReLU2d: mod_weight_detach, + nniq.ConvReLU2d: mod_weight_bias_0, + # Conv3d + nn.Conv3d: mod_weight_detach, + nni.ConvReLU3d: mod_0_weight_detach, + nnq.Conv3d: mod_weight_bias_0, + nnqat.Conv3d: mod_weight_detach, + nniqat.ConvBn3d: mod_weight_detach, + nniqat.ConvBnReLU3d: mod_weight_detach, + nniqat.ConvReLU3d: mod_weight_detach, + nniq.ConvReLU3d: mod_weight_bias_0, + # Linear + nn.Linear: mod_weight_detach, + nnq.Linear: mod_weight_bias_0, + nni.LinearReLU: mod_0_weight_detach, + nniq.LinearReLU: mod_weight_bias_0, + nnqat.Linear: mod_weight_detach, + nnqd.Linear: mod_weight_bias_0, + nniqat.LinearReLU: mod_weight_detach, + nniqat.LinearBn1d: mod_weight_detach, + nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach, + # LSTM + nn.LSTM: get_lstm_weight, + nnqd.LSTM: get_qlstm_weight, + }, + "call_function": { + # Conv + F.conv1d: get_conv_fun_weight, + F.conv2d: get_conv_fun_weight, + F.conv3d: get_conv_fun_weight, + toq.conv1d: get_qconv_fun_weight, + toq.conv2d: get_qconv_fun_weight, + toq.conv3d: get_qconv_fun_weight, + toq.conv1d_relu: get_qconv_fun_weight, + toq.conv2d_relu: get_qconv_fun_weight, + toq.conv3d_relu: get_qconv_fun_weight, + # Linear + F.linear: get_linear_fun_weight, + toq.linear: get_qlinear_fun_weight, + toq.linear_relu: get_qlinear_fun_weight, + }, + } + + return op_to_type_to_weight_extraction_fn + + +def extract_weight_from_node( + node: Node, + gm: GraphModule, + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] + | None = None, +) -> NSSingleResultType | None: + res_type = NSSingleResultValuesType.WEIGHT.value + + # Not all graphmodules have _node_name_to_scope, so only fill it + # out if it exists. + fqn = None + if hasattr(gm, "_node_name_to_scope"): + fqn = gm._node_name_to_scope[node.name][0] # type: ignore[index] + + if op_to_type_to_weight_extraction_fn is None: + op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn() + + ref_node_type = get_target_type_str(node, gm) + # for extracting weights, these are always the same + prev_node_type = ref_node_type + + if node.op == "call_function": + function_mapping = op_to_type_to_weight_extraction_fn["call_function"] + for target_fn_type, weight_extraction_fn in function_mapping.items(): + if node.target == target_fn_type: + weight = weight_extraction_fn(node, gm) + return { + "type": res_type, + "values": [weight], + "prev_node_name": node.name, + "prev_node_target_type": prev_node_type, + "ref_node_name": node.name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + } + + elif node.op == "call_module": + # for call_module, we need to look up the modules to do the type check + if not isinstance(node.target, str): + raise AssertionError(f"Expected str, got {type(node.target)}") + mod = getattr_from_fqn(gm, node.target) + module_mapping = op_to_type_to_weight_extraction_fn["call_module"] + for target_mod_type, weight_extraction_fn in module_mapping.items(): + if type(mod) is target_mod_type: + weight = weight_extraction_fn(mod) + return { + "type": res_type, + "values": [weight], + "prev_node_name": node.name, + "prev_node_target_type": prev_node_type, + "ref_node_name": node.name, + "ref_node_target_type": ref_node_type, + "index_within_arg": 0, + "index_of_arg": 0, + "fqn": fqn, + } + + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d77ea1712ff48129dd0d5d9bfe2aed532b70070d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db8321abed775b7ca6a0672513be7602439fb56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f64ae77bf9a72cb147e2406c965b6a377b118b5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..241b4e70e8196e66b471e3855033e55e3e426249 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,482 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +__all__ = ["ActivationSparsifier"] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. + The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor. + Example + def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is passed in the + register_layer(). + Note that the mask_fn() definition should contain the sparse arguments that is passed in sparse_config + arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + mask = [mask_fn(reduce_fn(aggregated_fn(input[feature])) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + with the mask_fn() + + Example: + >>> # xdoctest: +SKIP + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer( + ... model.some_layer, + ... aggregate_fn=agg_fn, + ... reduce_fn=reduce_fn, + ... mask_fn=mask_fn, + ... ) + >>> + >>> # start training process + >>> for _ in [...]: + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + + def __init__( + self, + model: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + self.model = model + self.defaults: dict[str, Any] = defaultdict() + self.defaults["sparse_config"] = sparse_config + + # functions + self.defaults["aggregate_fn"] = aggregate_fn + self.defaults["reduce_fn"] = reduce_fn + self.defaults["mask_fn"] = mask_fn + + # default feature and feature_dim + self.defaults["features"] = features + self.defaults["feature_dim"] = feature_dim + + self.data_groups: dict[str, dict] = defaultdict( + dict + ) # contains all relevant info w.r.t each registered layer + + self.state: dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly""" + + # if features are not None, then feature_dim must not be None + features, feature_dim = args["features"], args["feature_dim"] + if features is not None: + if feature_dim is None: + raise AssertionError("need feature dim to select features") + + # all the *_fns should be callable + fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"] + for key in fn_keys: + fn = args[key] + if not callable(fn): + raise AssertionError(f"{fn} must be callable") + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through.""" + + # gather some data + feature_dim = self.data_groups[name]["feature_dim"] + features = self.data_groups[name]["features"] + agg_fn = self.data_groups[name]["aggregate_fn"] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get("data") # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]["mask"] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [ + 0 for _ in range(len(features)) + ] # create one in case of 1st forward + self.state[name]["mask"] = [0 for _ in range(len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + data_feature = torch.index_select( + input_data, feature_dim, feature_tensor + ) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]["mask"][feature_idx] = torch.ones_like( + data_feature + ) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]["data"] = out_data + + return hook + + def register_layer( + self, + layer: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. + + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + if name is None: + raise AssertionError("layer not found in the model") + + if name in self.data_groups: # unregister layer if already present + warnings.warn( + "layer already attached to the sparsifier, deregistering the layer and registering with new config", + stacklevel=2, + ) + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + "aggregate_fn": aggregate_fn, + "reduce_fn": reduce_fn, + "mask_fn": mask_fn, + "features": features, + "feature_dim": feature_dim, + "layer": layer, + } + local_args.update( + (arg, val) for arg, val in update_dict.items() if val is not None + ) + local_args["sparse_config"].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]["mask"] = ( + None # mask will be created when model forward is called. + ) + + # attach agg hook + self.data_groups[name]["hook"] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]["hook_state"] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: str | None = None, layer: nn.Module | None = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + if name is None and layer is None: + raise AssertionError("Need at least name or layer obj to retrieve mask") + + if name is None: + if layer is None: + raise AssertionError("layer must be provided when name is None") + name = module_to_fqn(self.model, layer) + if name is None: + raise AssertionError("layer not found in the specified model") + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get("mask", None) + + if mask is None: + raise ValueError( + "Error: shape unknown, call layer() routine at least once to infer mask" + ) + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer""" + + # detach any hooks attached + self.data_groups[name]["hook"].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer""" + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs["data"] + self.update_mask(name, data, configs) + + self.data_groups[name].pop("data") # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs["sparse_config"] + features = configs["features"] + reduce_fn = configs["reduce_fn"] + mask_fn = configs["mask_fn"] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer""" + mask = self.get_mask(name) + features = self.data_groups[name]["features"] + feature_dim = self.data_groups[name]["feature_dim"] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(len(features)): + feature = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + sparsified = ( + torch.index_select(input_data, feature_dim, feature) + * mask[feature_idx] + ) + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggregate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs["hook"].remove() + configs.pop("hook") + self.data_groups[name]["hook_state"] = "None" + if attach_sparsify_hook: + configs["hook"] = configs["layer"].register_forward_pre_hook( + self._sparsify_hook(name) + ) + configs["hook_state"] = ( + "sparsify" # signals that sparsify hook is now attached + ) + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = { + key: value + for key, value in config.items() + if key not in ["hook", "layer"] + } + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for state in states.values(): + if state["mask"] is not None: + if isinstance(state["mask"], list): + for idx in range(len(state["mask"])): + if sparse_coo: + state["mask"][idx] = state["mask"][idx].to_sparse_coo() + else: + state["mask"][idx] = state["mask"][idx].to_dense() + else: + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + return states + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return {"state": state, "data_groups": data_groups, "defaults": self.defaults} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict["state"] + data_groups, defaults = state_dict["data_groups"], state_dict["defaults"] + + self.__set_state__( + {"state": state, "data_groups": data_groups, "defaults": defaults} + ) + + def __get_state__(self) -> dict[str, Any]: + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": data_groups, + } + + def __set_state__(self, state: dict[str, Any]) -> None: + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + if layer is None: + raise AssertionError(f"layer {name} not found in the model") + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config["hook_state"] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config["layer"] = layer + config["hook"] = hook # type: ignore[possibly-undefined] + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for name, config in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(config.keys()): + if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]: + continue + format_string += f"\t {key}: {config[key]}\n" + format_string += ")" + return format_string diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7564fe408b36e5fb62eb4cb2272ef432095981 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py @@ -0,0 +1,6 @@ +from .base_data_scheduler import BaseDataScheduler + + +__all__ = [ + "BaseDataScheduler", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8c4cef62169d94cd38c2aa96ddb42e932c59367 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b430a83e9e913ec67f39eb368827d48c5f5d05f7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f48abfc9deec2393816ffd227cc414f9f14a29 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs +import abc +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier + + +__all__ = ["BaseDataScheduler"] + + +class BaseDataScheduler: + r""" + The BaseDataScheduler is the abstract scheduler class specifically for the + BaseDataSparsifier class. This class controls a specific hyperparameter of + the sparsifier class and varies it across the training process (or across time). + + Args: + data_sparsifier (instance of BaseDataSparsifier) + Implemented class data sparsifier class wherein the update_mask is implemented + schedule_param (str) + A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied + last_epoch (int, default=-1) + This is specifically is passed when training needs to be resumed from a particular + point. + verbose (bool, default=False) + Verbosity of the BaseDataScheduler + + The *get_hyperparam()* function needs to be implemented by the user. + """ + + def __init__( + self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False + ): + # Attach sparsifier + if not isinstance(data_sparsifier, BaseDataSparsifier): + raise TypeError( + f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier" + ) + self.data_sparsifier = data_sparsifier + self.schedule_param = schedule_param + + # Initialize epoch and base hyper-params + self.base_param = { + name: config.get(schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.data_sparsifier.step = with_counter(self.data_sparsifier.step) # type: ignore[assignment] + self.data_sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sp_called_within_step: bool = False # sp -> schedule parameter + self.step() + + @abc.abstractmethod + def get_schedule_param(self): + r""" + Abstract method that needs to be implemented by the child class. + The expected return type should is a dictionary of name to schedule_param value + The returned values will be updated in sparsifier when the scheduler step() function + is called. + + Example: + >>> def get_schedule_param(self): + ... new_param = {} + ... for name in self.sparsifier.data_groups.keys(): + ... new_param[name] = ( + ... self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... ) + ... return new_param + + When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] + would be halved + """ + raise NotImplementedError + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Data Sparsifier {self.data_sparsifier}\n" + format_string += f" {self.schedule_param}: {self.base_param}\n" + format_string += ")" + return format_string + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + + Note: + The scheduler class does not track the state of the data_sparsifier. + Make sure to store the state of the sparsifier before storing the + state of the scheduler + """ + return { + key: value + for key, value in self.__dict__.items() + if key != "data_sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Note: + Remember to restore the state of the data_sparsifier before the scheduler. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_param(self): + return self._last_param + + def step(self): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.data_sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `data_sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + stacklevel=2, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `data_sparsifier.step()`. " + "You have to make sure you run the data_sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + stacklevel=2, + ) + self._step_count += 1 + + class _enable_get_sp_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sp_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sp_called_within_step = False + + with _enable_get_sp_call(self): + self.last_epoch += 1 + updated_scheduler_params = self.get_schedule_param() + + for name, param in updated_scheduler_params.items(): + self.data_sparsifier.data_groups[name][self.schedule_param] = param + if self.verbose: + print(f"Adjusting {self.schedule_param} for group {name} to {param}") + + self._last_param = { + name: config.get(self.schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + self.data_sparsifier.enable_mask_update = True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b5b9b96ec96fffdb0b66e21686a927a0c41b4a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py @@ -0,0 +1,8 @@ +from .base_data_sparsifier import BaseDataSparsifier +from .data_norm_sparsifier import DataNormSparsifier + + +__all__ = [ + "BaseDataSparsifier", + "DataNormSparsifier", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6872164d752c84283ea05d76842698645105d4af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..846fe4cdcd4255d34c9f3674b01e9a5407e513dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9df9efd4d760826701f94b12c52a889b8edf7df6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df56b8d4f709d6c1fd7e48d860283c835fa8126 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..e76b5ccd7b5c571f636cce6a2f8beb907f50004e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -0,0 +1,334 @@ +# mypy: allow-untyped-defs +import abc +import copy +import sys +import warnings +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.ao.pruning.sparsifier import base_sparsifier, utils +from torch.nn.utils import parametrize + + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") + +__all__ = ["BaseDataSparsifier"] + +EMBEDDING_TYPES = { + nn.Embedding, + nn.EmbeddingBag, +} + +SUPPORTED_TYPES = { + torch.Tensor, + nn.Parameter, + *EMBEDDING_TYPES, +} + + +class _Container(nn.Module): + pass + + +class BaseDataSparsifier(base_sparsifier.BaseSparsifier): + r""" + Base Data Sparsifier class for all Data sparsifiers. + The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) + to prepare for sparsification. + In this case, mask (and parametrizations) is owned by the class and not by the user. + Specifically, the container object inside the class maintains the mask and parametrizations of the input data + + Args: + data_list (list of tuples) + list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES + for type of data. Internally, a container module handles the data sparsification. + + defaults (dict) + default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + Example:: + >>> # xdoctest: +SKIP + >>> data_list = [('tensor_1', torch.randn(3,3)), ('tensor_2', torch.randn(4,4))] + >>> defaults = {'sparsity_level': 0.7} + >>> sparsifier = DerivedDataSparsifier(data_list = data_list, **defaults) # Some sparsifier that inherits BaseDataSparsifier + >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5,5), 'sparsity_level': 0.3} + >>> sparsifier.add_data(**new_tensor_to_add) + >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3 + """ + + def __init__(self, data_list: list[tuple[str, Any]] | None = None, **defaults): + super().__init__(defaults=defaults) + + self._container = _Container() + + self.data_groups: dict[str, dict] = defaultdict(dict) # name -> {**config} + if data_list is not None: + # add data with default config here + [self.add_data(name, data, **self.defaults) for name, data in data_list] + + def prepare(self, model, config): + raise NotImplementedError("this function is undefined for this class") + + def _extract_weight(self, data): + # extract the weight parameter instead of underlying data + if type(data) in [torch.Tensor, nn.Parameter]: + return data + elif type(data) in EMBEDDING_TYPES: + return data.weight + + def add_data(self, name: str, data, reuse_mask=True, **config): + r"""Configures and parametrizes the internal container model with name and data. + + **Note**: + 1. If the data with name already exists, it replaces the data. + 2. While replacing, the old mask is reused when `reuse_mask=True` + 3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data. + 4. By default, the config of the replaced data is used as config for the replacing data, unless something + is specified in the config dictionary. + """ + if type(data) not in SUPPORTED_TYPES: + raise AssertionError( + f"specified data type:{type(data)} not supported at the moment" + ) + local_args = copy.deepcopy(self.defaults) + local_args.update(config) + weight = self._extract_weight(data) + + # Bookkeeping in the container class + mask = local_args.get("mask", torch.ones_like(weight)) + param_class = local_args.get("parametrization", utils.FakeSparsity) + + if name in self.state: + # If the named data already exists - replace + warnings.warn( + "Replacing existing data of the same name. - Did you mean a different name?", + stacklevel=2, + ) + + # reuse old config + old_args = self.data_groups[name] + local_args = copy.deepcopy(old_args) + local_args.update(config) + + if reuse_mask: + current_data = self.get_data(name=name) + if weight.shape != current_data.shape: + raise AssertionError( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) + mask = self.get_mask( + name=name + ) # reuse mask instead of creating a new one + + self._delete_data(name=name) + + # parameter creates a deepcopy of the weight inside, so create a buffer + self._container.register_buffer(name=name, tensor=weight) + parametrize.register_parametrization(self._container, name, param_class(mask)) + self.state[name]["mask"] = mask + self.data_groups[name] = local_args + return getattr(self._container, name) + + def get_data(self, name: str, return_original: bool = True): + r"""Returns weight tensor (or data) + Args: + - name: name of the data to be returned + - return_original returns weight tensor without applying parametrization if True + else - returns the sparsified version (parametrized) + """ + if name not in self.data_groups: + raise ValueError("data with specified name does not exist") + + if return_original: + if not parametrize.is_parametrized(self._container, name): + raise ValueError("mask squashed - original mask value does not exist") + data = getattr(self._container.parametrizations, name).original + return data + else: + return getattr(self._container, name) + + def _convert_mask(self, states, sparse_coo=True): + r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument.""" + states = copy.deepcopy(states) + for state in states.values(): + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + + return states + + def state_dict(self): + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a list containing all sparsity configuration groups + with the key name specifying the name of the data + * container_state_dict - the state dictionary of the internal + container model used for sparsification + """ + state = self._convert_mask(self.state) + return { + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def _load_container_from_state(self, states, data_groups, container_state_dict): + r"""This restores the state of the container specifically based on the data present in state and data_groups + If the data was parametrized, then the data would be added to the container and then parametrized, + else it would just add the attribute the container. + """ + for name, state in states.items(): + config_name = data_groups.get(name, None) + if config_name is None: + raise RuntimeError(f"Error loading {name}") + + # check if the data with such a name was parametrized, if so parametrize + # otherwise just set the attribute and continue + parametrized_name = f"parametrizations.{name}.original" + parametrized = False + data = container_state_dict.get(name, None) + if name in container_state_dict: + # the parametrization was probably removed for this + data = container_state_dict.get(name) + + elif parametrized_name in container_state_dict: + # so the weight was parametrized + data = container_state_dict.get(parametrized_name) + parametrized = True + + else: + raise RuntimeError(f"Error loading {name}") + + self._container.register_buffer(name=name, tensor=data) + + if parametrized: + # register parameter if parametrized + mask = state.get("mask", torch.ones_like(data)) + param_class = data_groups.get( + "parametrization", utils.FakeSparsity + ) # change once public_api for utils is fixed! + parametrize.register_parametrization( + self._container, name, param_class(mask) + ) + + def load_state_dict(self, state_dict, strict=True): + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + * strict - If True - the sparsifier is reset and is restored exactly to the state in state_dict. + If False - the current sparsifier is not reset before loading the state_dict i.e. data added + before loading the state_dict is not erased. + """ + states = copy.deepcopy(state_dict["state"]) + data_groups = copy.deepcopy(state_dict["data_groups"]) + container_state_dict = copy.deepcopy(state_dict["_container"]) + + states = self._convert_mask( + states, sparse_coo=False + ) # convert sparse coo mask to dense + if strict: + # if strict load -> then reset container + self._container = _Container() + + self._load_container_from_state(states, data_groups, container_state_dict) + + if not strict: + states.update(self.state) + data_groups.update(self.data_groups) + + self.__setstate__({"state": states, "data_groups": data_groups}) + + def __setstate__(self, state): + if "_container" in state: # If container object is in state then load model + container_dict = state.pop("_container") + self._container = _Container() + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert sparse coo mask to dense + self._load_container_from_state( + state["state"], state["data_groups"], container_dict + ) + + self.__dict__.update(state) + + def __getstate__(self): + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def __repr__(self): # type:ignore[override] + format_string = self.__class__.__name__ + " (" + for name, sparse_args in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(sparse_args.keys()): + if key == "data": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def get_mask(self, name: str): + if name not in self.state: + raise ValueError("data with specified name does not exist") + return self.state[name]["mask"] + + def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): + r"""Squashes the sparse masks into the appropriate tensors. Also, accepts list of strings + to squash mask for. If none, squashes mask for all the keys + kwargs: + * names: list of strings to squash mask for + * sparsified: if true - applies the mask before squashing + if false - does not apply the mask before squashing + """ + if names is None: + names = list(self.data_groups.keys()) + for name in names: + parametrize.remove_parametrizations( + self._container, name, leave_parametrized=leave_parametrized + ) + + def step(self): # type:ignore[override] + if not self.enable_mask_update: + return + with torch.no_grad(): + for name, config in self.data_groups.items(): + # get non-sparsified data + data = self.get_data(name) + # need name for the mask otherwise can directly pass mask? + self.update_mask(name, data, **config) + + @abc.abstractmethod + def update_mask(self, name, data, **kwargs): # type: ignore[override] + pass + + def _delete_data(self, name): + """Detaches some data from the sparsifier. + + Args: + name (str) + Name of the data to be removed from the sparsifier + + Note: + Currently private. Kind of used as a helper function when replacing data of the same name + """ + self.squash_mask( + names=[name], leave_parametrized=False + ) # do not apply the mask while deleting + delattr(self._container, name) + self.state.pop(name) + self.data_groups.pop(name) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2971cd0b3d0cae7afb8763bee319ac819ad2f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -0,0 +1,204 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Any + +import torch +from torch.nn import functional as F + +from .base_data_sparsifier import BaseDataSparsifier + + +__all__ = ["DataNormSparsifier"] + + +class DataNormSparsifier(BaseDataSparsifier): + r"""L1-Norm Sparsifier + This sparsifier computes the *L1-norm* of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + Args: + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block + zeros_per_block: Number of zeros in a sparse block + Note:: + All arguments to the DataNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `add_data` step. + """ + + def __init__( + self, + data_list: list[tuple[str, Any]] | None = None, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: int | None = None, + norm: str = "L1", + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + + if norm not in ["L1", "L2"]: + raise AssertionError("only L1 and L2 norm supported at the moment") + + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + self.norm = norm + super().__init__(data_list=data_list, **defaults) + + def __get_scatter_folded_mask( + self, data, dim, indices, output_size, sparse_block_shape + ): + mask = torch.ones_like(data) + mask.scatter_(dim=dim, index=indices, value=0) # zeroing out + mask = F.fold( + mask, + output_size=output_size, + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + mask = mask.to(torch.int8) + return mask + + def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block): + # Assume data is a squeezed tensor + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + values_per_block = block_height * block_width + + # just return zeros if zeroing all elements in block + if values_per_block == zeros_per_block: + return torch.zeros_like(data, dtype=torch.int8) + + # creating additional height and width to support padding + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones( + height + dh, width + dw, dtype=data.dtype, device=data.device + ) + padded_data = ( + padded_data * torch.nan + ) # can also be replaced with 0 to stop the removal of edge data + padded_data[0:height, 0:width] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + _, sorted_idx = torch.sort(unfolded_data, dim=1) + sorted_idx = sorted_idx[ + :, :zeros_per_block, : + ] # zero out zeros_per_block number of elements + + mask = self.__get_scatter_folded_mask( + data=unfolded_data, + dim=1, + indices=sorted_idx, + output_size=padded_data.shape, + sparse_block_shape=sparse_block_shape, + ) + + mask = ( + mask.squeeze(0).squeeze(0)[:height, :width].contiguous() + ) # remove padding and make contiguous + return mask + + def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + data_norm = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + + values_per_block = reduce(operator.mul, sparse_block_shape) + + data_norm = data_norm.flatten() + num_blocks = len(data_norm) + + data_norm = data_norm.repeat( + 1, values_per_block, 1 + ) # get similar shape after unfold + _, sorted_idx = torch.sort(data_norm, dim=2) + + threshold_idx = round(sparsity_level * num_blocks) # number of blocks to remove + sorted_idx = sorted_idx[:, :, :threshold_idx] + + mask = self.__get_scatter_folded_mask( + data=data_norm, + dim=2, + indices=sorted_idx, + output_size=(height + dh, width + dw), + sparse_block_shape=sparse_block_shape, + ) + + mask = mask.squeeze(0).squeeze(0)[ + :height, :width + ] # squeeze only the first 2 dimension + return mask + + def update_mask( # type: ignore[override] + self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than " + "the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + if self.norm == "L1": + data_norm = torch.abs(data).squeeze() # absolute value based (L1) + else: + data_norm = (data * data).squeeze() # square every element for L2 + + if len(data_norm.shape) > 2: # only supports 2 dimensional data at the moment + raise ValueError("only supports 2-D at the moment") + + elif len(data_norm.shape) == 1: # in case the data is bias (or 1D) + data_norm = data_norm[None, :] + + mask = self.get_mask(name) + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + + # Fetch the high level mask that zeros out entire blocks + data_lvl_mask = self.__get_data_level_mask( + data=data_norm, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + + # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block + block_lvl_mask = self.__get_block_level_mask( + data=data_norm, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + + # zero out the entries inside those blocks whose block is sparsified + mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51e5ee8e9ff90e7a9f69fde690a3d45b9d514a57 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b30e05be1b7e2c5043e7e977b47dc85e5d8f772a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eecd8ea4cdef23209fa2e620b15c57b46cea6cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c929217af0af43823fd48f9055cde0fd73b84ed7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50d5684961bc807d5ae1b02615ade168416c9b3d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import logging + +from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import ( + SUPPORTED_TYPES, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None): + """Attaches a data sparsifier to all the layers of the module. + Essentially, loop over all the weight parameters in the module and + attach it to the data sparsifier. + Note:: + The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below) + before attaching to the sparsifier. This is because, the data + sparsifier uses a dummy model inside to store the weight parameters. + """ + if config is None: + config = {} + for name, parameter in module.named_parameters(): + if type(parameter) in SUPPORTED_TYPES: + valid_name = _get_valid_name(name) + # will be defaulted to default configs + data_sparsifier.add_data( + name=valid_name, data=parameter, **config.get(valid_name, {}) + ) + + +def _get_valid_name(name): + return name.replace(".", "_") # . is not allowed as a name + + +def _log_sparsified_level(model, data_sparsifier) -> None: + # Show the level of sparsity AFTER step: + for name, parameter in model.named_parameters(): + if type(parameter) not in SUPPORTED_TYPES: + continue + valid_name = _get_valid_name(name) + mask = data_sparsifier.get_mask(name=valid_name) + sparsity_level = 1.0 - mask.float().mean() + logger.info("Sparsity in layer %s = % .2%", name, sparsity_level) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c8a91c5c9dcea9ad5cceaa9ecc80e3f32bd8a7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from copy import deepcopy +from typing import Any, TYPE_CHECKING + +import pytorch_lightning as pl # type: ignore[import] + +from ._data_sparstity_utils import ( + _attach_model_to_data_sparsifier, + _get_valid_name, + _log_sparsified_level, +) + + +if TYPE_CHECKING: + import torch + + +class PostTrainingDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables post-training sparsity. + + This callback aims to sparsify the model inside lightning module after training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + once the training completes. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + Hooks implemented: + on_fit_end() + 1. copies the model and attaches it to the sparsifier + 2. sparsier step() is called + 3. squashes the mask() + """ + + def __init__(self, data_sparsifier_class, data_sparsifier_args): + super().__init__() + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + self.data_sparsifier: Any = None + self.sparsified: torch.nn.Module | None = None + + def on_fit_end(self, trainer, pl_module) -> None: + self.sparsified = deepcopy(pl_module.model).eval() + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) + + self.data_sparsifier.step() + + self.data_sparsifier.squash_mask() # currently squashes params for all mask + + _log_sparsified_level(self.sparsified, self.data_sparsifier) + + +class TrainingAwareDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables in-training sparsity. + + This callback aims to sparsify the model inside lightning module during training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + when the training starts. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + data_scheduler_class (some implemented class of BaseDataScheduler) + The data scheduler of this class is created when the training starts + Note: Objects should not be passed in here as they are created + when the training starts. + + data_scheduler_args(Dict) + Dictionary of args to be passed to the data scheduler. + **Note: data_sparsifier arg should be ignored as the recipe + creates and pass sparsifier object into the class** + + Hooks implemented: + on_train_start() + Data sparsifier and scheduler objects are created. + Pytorch model attached to the sparsifier + + on_train_epoch_start() + Loads the state_dict of the data sparsifier + + on_train_epoch_end() + 1. Copies the model and attaches it to the sparsifier + 2. sparsifier step() and scheduler step() + 3. Dump state_dict of the current sparsifier + + on_train_end() + squash mask + """ + + def __init__( + self, + data_sparsifier_class, + data_sparsifier_args, + data_scheduler_class, + data_scheduler_args, + ): + super().__init__() + # data sparsifier objects + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + + # scheduler objects + self.data_scheduler_class = data_scheduler_class + self.data_scheduler_args = data_scheduler_args + + # fields + self.data_sparsifier: Any = None + self.data_scheduler: Any = None + self.sparsified: torch.nn.Module | None = None + + self.data_sparsifier_state_dict: Any = None + + def on_train_start(self, trainer, pl_module) -> None: + # create sparsifier + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + self.sparsified = deepcopy(pl_module.model) + + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier + ) # just to populate the base_sl in the scheduler + + # create scheduler + args = deepcopy(self.data_scheduler_args) + args["data_sparsifier"] = self.data_sparsifier + self.data_scheduler = self.data_scheduler_class(**args) + + def on_train_epoch_start(self, trainer, pl_module): + if self.data_sparsifier_state_dict is None: + return # probably first epoch + + # load the existing config for each data + self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict) + + def __create_config_based_on_state(self, pl_module): + config: dict = defaultdict() + if self.data_sparsifier_state_dict is None: + return config + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + config[valid_name] = self.data_sparsifier.data_groups[valid_name] + + return config + + def on_train_epoch_end(self, trainer, pl_module): + self.sparsified = deepcopy(pl_module.model) + config = self.__create_config_based_on_state(pl_module) + + # attach model to the data sparsifier + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier, config=config + ) + self.data_sparsifier.step() + self.data_scheduler.step() + + self.data_sparsifier_state_dict = self.data_sparsifier.state_dict() + + def on_train_end(self, trainer, pl_module): + self.data_sparsifier.squash_mask() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b727635d08151abd39c94ee40b0417afea97a05b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -0,0 +1,154 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn as nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag} + + +def _fetch_all_embeddings(model): + """Fetches Embedding and EmbeddingBag modules from the model""" + embedding_modules = [] + stack = [model] + while stack: + module = stack.pop() + for _, child in module.named_children(): + fqn_name = module_to_fqn(model, child) + if type(child) in SUPPORTED_MODULES: + embedding_modules.append((fqn_name, child)) + else: + stack.append(child) + return embedding_modules + + +def post_training_sparse_quantize( + model, + data_sparsifier_class, + sparsify_first=True, + select_embeddings: list[nn.Module] | None = None, + **sparse_config, +): + """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. + The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. + + Args: + - model (nn.Module) + model whose embeddings needs to be sparsified + - data_sparsifier_class (type of data sparsifier) + Type of sparsification that needs to be applied to model + - sparsify_first (bool) + if true, sparsifies first and then quantizes + otherwise, quantizes first and then sparsifies. + - select_embeddings (List of Embedding modules) + List of embedding modules to in the model to be sparsified & quantized. + If None, all embedding modules with be sparsified + - sparse_config (Dict) + config that will be passed to the constructor of data sparsifier object. + + Note: + 1. When `sparsify_first=False`, quantization occurs first followed by sparsification. + - before sparsifying, the embedding layers are dequantized. + - scales and zero-points are saved + - embedding layers are sparsified and `squash_mask` is applied + - embedding weights are requantized using the saved scales and zero-points + 2. When `sparsify_first=True`, sparsification occurs first followed by quantization. + - embeddings are sparsified first + - quantization is applied on the sparsified embeddings + """ + data_sparsifier = data_sparsifier_class(**sparse_config) + + # if select_embeddings is None, perform it on all embeddings + if select_embeddings is None: + embedding_modules = _fetch_all_embeddings(model) + + else: + embedding_modules = [] + if not isinstance(select_embeddings, list): + raise AssertionError( + "the embedding_modules must be a list of embedding modules" + ) + for emb in select_embeddings: + if type(emb) not in SUPPORTED_MODULES: + raise AssertionError( + "the embedding_modules list must be an embedding or embedding bags" + ) + fqn_name = module_to_fqn(model, emb) + if fqn_name is None: + raise AssertionError( + "the embedding modules must be part of input model" + ) + embedding_modules.append((fqn_name, emb)) + + if sparsify_first: + # sparsify + for name, emb_module in embedding_modules: + valid_name = name.replace(".", "_") + data_sparsifier.add_data(name=valid_name, data=emb_module) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + else: + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + # retrieve scale & zero_points + quantize_params: dict[str, dict] = { + "scales": {}, + "zero_points": {}, + "dequant_weights": {}, + "axis": {}, + "dtype": {}, + } + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") + + quantized_weight = quantized_emb.weight() # type: ignore[operator] + quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() + quantize_params["zero_points"][name] = ( + quantized_weight.q_per_channel_zero_points() + ) + quantize_params["dequant_weights"][name] = torch.dequantize( + quantized_weight + ) + quantize_params["axis"][name] = quantized_weight.q_per_channel_axis() + quantize_params["dtype"][name] = quantized_weight.dtype + + # attach data to sparsifier + data_sparsifier.add_data( + name=name.replace(".", "_"), + data=quantize_params["dequant_weights"][name], + ) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") + requantized_vector = torch.quantize_per_channel( + quantize_params["dequant_weights"][name], + scales=quantize_params["scales"][name], + zero_points=quantize_params["zero_points"][name], + dtype=quantize_params["dtype"][name], + axis=quantize_params["axis"][name], + ) + + quantized_emb.set_weight(requantized_vector) # type: ignore[operator] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a57db6a8d8cde9a89c7cbda4dff6f6075559b59b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py @@ -0,0 +1,5 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .FPGM_pruner import FPGMPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .parametrization import BiasHook, FakeStructuredSparsity +from .saliency_pruner import SaliencyPruner diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29bcf36fe8880f6c316d55895f8f8efad6e3a104 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..330226e07da5563afd6a7850701fd5e0b1109ba3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..597fd639c47ca98d5cf987cbad266eebf45cf794 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6558bf79356f67967bca766151ca33e6e14e2f52 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1f626f59db18260bf4b33b097dea373be80575a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..075b61e2d6524fb949a1db6fe73257a5a2e56779 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ef5b9850929dc43212bd9ef71951cd33d286077 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9d4f3bfd9a2b58fd67b3a86fba196db3be3ebe8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..d1676292f7d74c4a620de0a53334d6dcd33aa764 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -0,0 +1,313 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from itertools import chain +from operator import getitem + +import torch +import torch.nn.functional as F +from torch import nn +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize + +from .match_utils import apply_match, MatchAllNode +from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param +from .prune_functions import ( + prune_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_linear, + prune_linear_activation_linear, + prune_linear_linear, + prune_lstm_output_layernorm_linear, + prune_lstm_output_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> dict[ + tuple[type[nn.Module] | Callable | MatchAllNode | str, ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: dict[ + tuple[type[nn.Module] | Callable | MatchAllNode | str, ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. + # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: set[type] | None = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + # pyrefly: ignore [bad-assignment] + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) # type: ignore[union-attr, index] + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( # noqa: TRY002 + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" + ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced # type: ignore[return-value] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..f904cc3ab8c4c34a193dd30926fff164010287a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -0,0 +1,54 @@ +from typing import Any, cast + +import torch +from torch import nn + +from .base_structured_sparsifier import BaseStructuredSparsifier +from .parametrization import FakeStructuredSparsity + + +class LSTMSaliencyPruner(BaseStructuredSparsifier): + """ + Prune packed LSTM weights based on saliency. + For each layer {k} inside a LSTM, we have two packed weight matrices + - weight_ih_l{k} + - weight_hh_l{k} + + These tensors pack the weights for the 4 linear layers together for efficiency. + + [W_ii | W_if | W_ig | W_io] + + Pruning this tensor directly will lead to weights being misassigned when unpacked. + To ensure that each packed linear layer is pruned the same amount: + 1. We split the packed weight into the 4 constituent linear parts + 2. Update the mask for each individual piece using saliency individually + + This applies to both weight_ih_l{k} and weight_hh_l{k}. + """ + + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None: + weights = getattr(module, tensor_name) + + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = cast(torch.Tensor, p.mask) + + # select weights based on magnitude + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + # take norm over all but first dim + dims = tuple(range(1, weights.dim())) + saliency = weights.norm(dim=dims, p=1) + + # handle weights in 4 groups + split_size = len(mask) // 4 + masks = torch.split(mask, split_size) + saliencies = torch.split(saliency, split_size) + + for keep_mask, sal in zip(masks, saliencies): + # mask smallest k values to be removed + k = int(len(keep_mask) * kwargs["sparsity_level"]) + prune = sal.topk(k, largest=False, sorted=False).indices + keep_mask.data[prune] = False # modifies underlying p.mask directly diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e22b979ab900c63a9a975b7a07c9b2a64ed8c0b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -0,0 +1,65 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" + +from typing import Any + +import torch +from torch import nn +from torch.ao.quantization.utils import MatchAllNode +from torch.fx import Node +from torch.nn.utils import parametrize + + +def _match( + modules: dict[str, nn.ModuleDict], + node: Node, + current: nn.Module | Any, +) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index] + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + + +def apply_match( + modules: dict[str, nn.ModuleDict], + pattern: tuple[Any] | Any, + node: Node, + matched_node_pattern: list[Node], +) -> list[Node] | None: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..4256d6fd01750d4408b92342bfb8d12239bf129a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import torch +from torch import nn +from torch.nn.utils.parametrize import is_parametrized + + +def module_contains_param(module, parametrization): + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() + ) + return False + + +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + if not isinstance(self.mask, torch.Tensor): + raise AssertionError("mask must be a torch.Tensor") + if self.mask.shape[0] != x.shape[0]: + raise AssertionError( + f"mask shape[0] ({self.mask.shape[0]}) must match x shape[0] ({x.shape[0]})" + ) + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + if getattr(module, "_bias", None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[~self.param.mask] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..14a1c9a97b07ccb87a5ffab2923a105b7abbd6d4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propagation +""" + +from collections.abc import Callable +from typing import cast + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList + +from .parametrization import BiasHook, FakeStructuredSparsity + + +# BIAS PROPAGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks: list[int] = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) # type: ignore[operator] + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) # type: ignore[operator] + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Tensor | None: + r""" + In the case that we need to propagate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + # pyrefly: ignore [bad-assignment] + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propagate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined] + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + # pyrefly: ignore [unbound-name] + return mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Callable[[Tensor], Tensor] | None, + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propagate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined] + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + # pyrefly: ignore [unbound-name] + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] + # adjusted bias that to keep in conv2d_1 + # pyrefly: ignore [unbound-name] + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined] + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Callable[[Tensor], Tensor] | None, + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propagate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Callable[[Tensor], Tensor] | None, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Callable[[Tensor], Tensor] | None, + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Callable[[Tensor], Tensor] | None, + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + if linear_ic % conv2d_oc != 0: + raise AssertionError( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] + + +def prune_lstm_output_linear( + lstm: nn.LSTM, getitem: Callable, linear: nn.Linear +) -> None: + prune_lstm_output_layernorm_linear(lstm, getitem, None, linear) + + +def prune_lstm_output_layernorm_linear( + lstm: nn.LSTM, + getitem: Callable, + layernorm: nn.LayerNorm | None, + linear: nn.Linear, +) -> None: + for i in range(lstm.num_layers): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_ih_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_ih_l{i}", leave_parametrized=True + ) + setattr( + lstm, + f"weight_ih_l{i}", + nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]), + ) + setattr( + lstm, + f"bias_ih_l{i}", + nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]), + ) + + if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_hh_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_hh_l{i}", leave_parametrized=True + ) + # splitting out hidden-hidden masks + W_hi, W_hf, W_hg, W_ho = torch.split( + getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size + ) + M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size) # type: ignore[arg-type] + + # resize each individual weight separately + W_hi = W_hi[M_hi][:, M_hi] + W_hf = W_hf[M_hf][:, M_hf] + W_hg = W_hg[M_hg][:, M_hg] + W_ho = W_ho[M_ho][:, M_ho] + + # concat, use this as new weight + new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho)) + setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight)) + setattr( + lstm, + f"bias_hh_l{i}", + nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]), + ) + + # If this is the final layer, then we need to prune linear layer columns + if i + 1 == lstm.num_layers: + lstm.hidden_size = int(M_hi.sum()) + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast( + nn.ModuleDict, linear.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, M_ho]) + linear.in_features = linear.weight.shape[1] + + # if layernorm module, prune weight and bias + if layernorm is not None: + layernorm.normalized_shape = (linear.in_features,) + layernorm.weight = nn.Parameter(layernorm.weight[M_ho]) + layernorm.bias = nn.Parameter(layernorm.bias[M_ho]) + + # otherwise need to prune the columns of the input of the next LSTM layer + else: + with torch.no_grad(): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"): + parametrization_dict = cast( + nn.ModuleDict, lstm.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, + getattr(parametrization_dict, f"weight_ih_l{i + 1}"), + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + else: + next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}") + setattr( + lstm, + f"weight_ih_l{i + 1}", + nn.Parameter(next_layer_weight[:, M_ho]), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..11c4652a7f0dafe2d3dd94f85c68fece035fd827 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, which is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + if saliency.shape != mask.shape: + raise AssertionError( + f"saliency shape ({saliency.shape}) must match mask shape ({mask.shape})" + ) + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f1100bd9cb7e2f70ade8f9757447bfa0eac988 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8068aa4aec5a94aa82e8d2a0b9baaa5bd88028a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f7bb43a8317a6351dd7102e34cf75b5fc24c462 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5166088816111028f62e455d810d0bf82422e976 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4174725b556c620c9e186b9c1cbd6fbe16d3b216 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..797776874cafaaf704ec546ff87cf6cd61fe570b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e96c8737b19df548a6d33225b99ca4f633364e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c651988dc756061f4df4f073376e4f7c8b625f2a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a31af8f1bed61cc4eedd42af24e5237058edb147 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cae552a9d68fe738d2966e817c9edbcd3783628 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f47cb8fb39f2f37994e9ac0b4308ebe6a68977c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adcdbcc0ade9aec0d41e5eea7bab48afed62bdc4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3a90f61785a0c7938188301ab25492e6cf8809f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c881367a645009e9937b22eb83e9a5409babc637 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e9e92d79c0e584277ece75510056eeabe191bea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a072f91d6822ccb3dd7e141141dead3c2dd73da9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f37a4964e99bb06f49e5224e75e5ca66d1b60ffb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b79991a344672f92c8c3728d479959350995d79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b5309938a72bdc9cc5629df91a4548f2a64397d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9dd74c724227971d0441d6e175d04c90db6e307 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4949bc0efb78489bbfdf68131834955aa72eee6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73001ae4bc28894363ca1d91ece94d7385e037a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e26b4f8ca083249e9390a2df8e6e9c0075049b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74fcae76ba3ad4353dd9093a4366a7d4f65bd7c2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a603f5fcb414ece002cb9a1c5d57ecd51ee3dce3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583391b91aae9da252dc427e7df64d8fa33763a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..427d17634bb8c37018c1bd296ce69b2581f37adf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d452359d41c36dc719e6df932b6fb018ca6a36b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py @@ -0,0 +1,30 @@ +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .executorch import get_executorch_backend_config +from .fbgemm import get_fbgemm_backend_config +from .native import get_native_backend_config, get_native_backend_config_dict +from .onednn import get_onednn_backend_config +from .qnnpack import get_qnnpack_backend_config +from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict + + +__all__ = [ + "get_fbgemm_backend_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_qnnpack_backend_config", + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", + "get_executorch_backend_config", + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", + "get_onednn_backend_config", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb322fd85d2c2b07a68f3b436e4f0536ac87e28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -0,0 +1,782 @@ +# mypy: allow-untyped-defs +import copy +import operator +from collections import namedtuple +from collections.abc import Callable + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_convtranspose_bn, + fuse_linear_bn, +) + +from .backend_config import ( + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) + + +__all__: list[str] = [] + +# TODO: rename to be more explicit, e.g. qat_conv_relu +_ConvMetadata = namedtuple( + "_ConvMetadata", + [ + "root", + "transpose", + "bn", + "reference", + "transpose_reference", + "fused_conv_relu", + "fused_conv_bn", + "fused_conv_bn_relu", + "qat", + "relu_qat", + "bn_qat", + "bn_relu_qat", + "func", + "func_transpose", + ], +) +_Conv1dMetadata = _ConvMetadata( + nn.Conv1d, + nn.ConvTranspose1d, + nn.BatchNorm1d, + nnqr.Conv1d, + nnqr.ConvTranspose1d, + nni.ConvReLU1d, + nni.ConvBn1d, + nni.ConvBnReLU1d, + nnqat.Conv1d, + nniqat.ConvReLU1d, + nniqat.ConvBn1d, + nniqat.ConvBnReLU1d, + F.conv1d, + F.conv_transpose1d, +) +_Conv2dMetadata = _ConvMetadata( + nn.Conv2d, + nn.ConvTranspose2d, + nn.BatchNorm2d, + nnqr.Conv2d, + nnqr.ConvTranspose2d, + nni.ConvReLU2d, + nni.ConvBn2d, + nni.ConvBnReLU2d, + nnqat.Conv2d, + nniqat.ConvReLU2d, + nniqat.ConvBn2d, + nniqat.ConvBnReLU2d, + F.conv2d, + F.conv_transpose2d, +) +_Conv3dMetadata = _ConvMetadata( + nn.Conv3d, + nn.ConvTranspose3d, + nn.BatchNorm3d, + nnqr.Conv3d, + nnqr.ConvTranspose3d, + nni.ConvReLU3d, + nni.ConvBn3d, + nni.ConvBnReLU3d, + nnqat.Conv3d, + nniqat.ConvReLU3d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU3d, + F.conv3d, + F.conv_transpose3d, +) + +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: dict[Callable | str, DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + + +def _get_binary_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + binary_op_configs: list[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + operator.add, + torch.add, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, nn.ReLU), + (op_with_quantized_bop_scalar_variant, F.relu), + (op_with_quantized_bop_scalar_variant, torch.relu), + op_with_quantized_bop_scalar_variant, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) # noqa: E131 + ) + return binary_op_configs + + +def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: list[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearReLU) + ) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig((F.linear, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig((F.linear, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # (3) Linear + batchnorm + # ------------------------ + # 3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_linear_bn) + .set_fused_module(nni.LinearBn1d) + ) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearBn1d) + ) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + return linear_configs + + +def _get_conv_configs(dtype_configs): + """ + Return all configs related to conv modules and ops. + """ + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv relu fusion, conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # conv relu, functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig((convs.transpose, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_convtranspose_bn) + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.3 functional conv transpose + conv_configs.append( + BackendPatternConfig(convs.func_transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + return conv_configs + + +def _get_cat_config(dtype_configs: list[DTypeConfig]) -> BackendPatternConfig: + return ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + +def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + ln_configs = [] + ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + ln_configs.append( + BackendPatternConfig(torch.nn.functional.layer_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + return ln_configs + + +def _get_default_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + configs = [ + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in default_ops + ] + + configs.append( + BackendPatternConfig(torch.nn.functional.group_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + + configs.append( + BackendPatternConfig(torch.nn.functional.instance_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 3, "bias": 4}) + ) + return configs + + +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: list[DTypeConfig], + constraints: DTypeWithConstraints, +) -> list[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. + """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [ + dc.input_dtype_with_constraints, + dc.output_dtype_with_constraints, + ]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError( + f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}" + ) + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError( + f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}" + ) + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + + +def _get_fixed_qparams_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + fixed_qparams_op_configs = [] + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs( + dtype_configs, constraints + ) + fixed_qparams_op_configs.append( + BackendPatternConfig(fixed_qparam_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(new_dtype_configs) + ) + return fixed_qparams_op_configs + + +def _get_share_qparams_op_configs(dtype_configs): + """Get the operator config for the operators that works for both float and quantized input + if input is quantized, the output Tensor shares the same quantization parameter + with input. + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + torch.narrow, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + "contiguous", + "clamp", + "detach", + "detach_", + "mean", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "relu", + "relu_", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + return [_get_share_qprams_op_backend_config(op) for op in share_qparams_ops] + + +def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """Get configs related to batchnorm.""" + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn: + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig((bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig((bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + bn_configs.append( + BackendPatternConfig(bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_rnn_op_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [ + (nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM), + (nn.GRU, nnqr.GRU), + ]: + rnn_op_configs.append( + BackendPatternConfig(rnn_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(rnn_op) + .set_reference_quantized_module(ref_rnn_op) + ) + return rnn_op_configs + + +def _get_embedding_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + return embedding_op_configs + + +def _get_tensor_info_op_configs(dtype_configs): + """ + These ops work on tensors of different dtypes but return non-tensors + containing information about the input tensor. + """ + + def _get_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED) + .set_dtype_configs(dtype_configs) + ) + + return [_get_config(op) for op in ("shape", "size")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e67b79c370207c228fd66d33fadad03a58ed2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + + +def get_linear_configs(): + linear_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + # TODO: need to fix the way we insert observers for this pattern + # should be solved in the new fusion API + # reason that this doesn't work: the pattern is a bit complicated and we don't + # have a way to specify which input of the pattern we would like to observe + # pattern: + # bias input weight + # \ | / + # \ | t + # \ | / + # addmm + # we want to observe "weight" as weight, but there is not way to convey this + # information with current pattern language + # + # right now: + # original: + # weight - t \ + # input - addmm + # observed (no hack): + # weight - t - observer \ + # input - observer - addmm + # target: + # weight - observer - t \ + # input - observer - addmm + + # def root_node_getter(node_pattern): + # addmm, bias, act, weight = node_pattern + # return addmm + + # linear_configs.append( + # BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default)) + # .set_observation_type(observation_type) # noqa: E131 + # .set_dtype_configs(dtype_configs) + # ._set_root_node_getter(root_node_getter)) + + linear_configs.append( + BackendPatternConfig(torch.ops.aten.addmm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 0}) + ) + # linear is decomposed to `t - mm` if bias is not present + linear_configs.append( + BackendPatternConfig(torch.ops.aten.mm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return linear_configs + + +def get_conv_configs(): + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + conv_configs.append( + BackendPatternConfig(torch.ops.aten.convolution.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + conv_configs.append( + BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + # TODO: remove when functionalization is supported in PT2 mode + conv_configs.append( + BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return conv_configs + + +def get_pooling_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + def root_node_getter(node_pattern): + _getitem, maxpool, _index = node_pattern + return maxpool + + backend_pattern_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_root_node_getter(root_node_getter) + ) + + return backend_pattern_configs + + +def get_relu_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + backend_pattern_configs.append( + BackendPatternConfig(torch.ops.aten.relu.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return backend_pattern_configs + + +def get_binary_op_configs(): + binary_op_configs: list[BackendPatternConfig] = [] + dtype_configs = [weighted_op_quint8_dtype_config] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default), + op_with_quantized_bop_scalar_variant, + # TODO: remove when functionalization is supported in pt2_mode + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default), + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + + return binary_op_configs + + +def get_qnnpack_pt2e_backend_config(): + return ( + BackendConfig("qnnpack_pytorch_2.0_export") + .set_backend_pattern_configs(get_linear_configs()) + .set_backend_pattern_configs(get_binary_op_configs()) + .set_backend_pattern_configs(get_conv_configs()) + .set_backend_pattern_configs(get_pooling_configs()) + .set_backend_pattern_configs(get_relu_configs()) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py new file mode 100644 index 0000000000000000000000000000000000000000..96a0b44a3afdf5a663319e67e1b422e17136d4d5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py @@ -0,0 +1,751 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", +] + + +# DTypeConfig dict keys +INPUT_DTYPE_DICT_KEY = "input_dtype" +OUTPUT_DTYPE_DICT_KEY = "output_dtype" +WEIGHT_DTYPE_DICT_KEY = "weight_dtype" +BIAS_DTYPE_DICT_KEY = "bias_dtype" +IS_DYNAMIC_DICT_KEY = "is_dynamic" + +# BackendConfig dict keys +NAME_DICT_KEY = "name" +CONFIGS_DICT_KEY = "configs" + +# BackendPatternConfig dict keys +PATTERN_DICT_KEY = "pattern" +PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format" +OBSERVATION_TYPE_DICT_KEY = "observation_type" +DTYPE_CONFIGS_DICT_KEY = "dtype_configs" +ROOT_MODULE_DICT_KEY = "root_module" +QAT_MODULE_DICT_KEY = "qat_module" +REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" +FUSED_MODULE_DICT_KEY = "fused_module" +FUSER_METHOD_DICT_KEY = "fuser_method" +ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" +EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" +NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" +INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" + + +# TODO: maybe rename this to something that's not related to observer +# e.g. QParamsType +class ObservationType(Enum): + """An enum that represents different ways of how an operator/operator pattern + should be observed + """ + + OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 + """this means input and output are observed with different observers, based + on qconfig.activation + example: conv, linear, softmax + """ + + OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 + """this means the output will use the same observer instance as input, based + on qconfig.activation + example: torch.cat, maxpool + """ + + INPUT_OUTPUT_NOT_OBSERVED = 2 + """this means the input and output are never observed + example: x.shape, x.size + """ + + +@dataclass +class DTypeWithConstraints: + """ + Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. + + The constraints currently supported are: + + * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper + bounds for the minimum and maximum quantized values respectively. If + the QConfig's `quant_min` and `quant_max` fall outside this range, + then the QConfig will be ignored. + + * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper + bounds for the minimum and maximum scale values respectively. If the + QConfig's minimum scale value (currently exposed as `eps`) falls below + the lower bound, then the QConfig will be ignored. Note that the upper + bound is currently not enforced. + + * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements + for scale and zero point, to be used for operators with fixed quantization + parameters such as sigmoid and tanh. If the observer specified in the QConfig + is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if + the quantization parameters don't match, then the QConfig will be ignored. + """ + + dtype: torch.dtype | None = None + quant_min_lower_bound: int | float | None = None + quant_max_upper_bound: int | float | None = None + scale_min_lower_bound: int | float | None = None + scale_max_upper_bound: int | float | None = None + scale_exact_match: float | None = None + zero_point_exact_match: int | None = None + + +@dataclass +class DTypeConfig: + """ + Config object that specifies the supported data types passed as arguments to + quantize ops in the reference model spec, for input and output activations, + weights, and biases. + + For example, consider the following reference model: + + quant1 - [dequant1 - fp32_linear - quant2] - dequant2 + + The pattern in the square brackets refers to the reference pattern of + statically quantized linear. Setting the input dtype as `torch.quint8` + in the DTypeConfig means we pass in `torch.quint8` as the dtype argument + to the first quantize op (quant1). Similarly, setting the output dtype as + `torch.quint8` means we pass in `torch.quint8` as the dtype argument to + the second quantize op (quant2). + + Note that the dtype here does not refer to the interface dtypes of the + op. For example, the "input dtype" here is not the dtype of the input + tensor passed to the quantized linear op. Though it can still be the + same as the interface dtype, this is not always the case, e.g. the + interface dtype is fp32 in dynamic quantization but the "input dtype" + specified in the DTypeConfig would still be quint8. The semantics of + dtypes here are the same as the semantics of the dtypes specified in + the observers. + + These dtypes are matched against the ones specified in the user's + QConfig. If there is a match, and the QConfig satisfies the constraints + specified in the DTypeConfig (if any), then we will quantize the given + pattern using this DTypeConfig. Otherwise, the QConfig is ignored and + the pattern will not be quantized. + + Example usage:: + + >>> # xdoctest: +SKIP(failing) + >>> dtype_config1 = DTypeConfig( + ... input_dtype=torch.quint8, + ... output_dtype=torch.quint8, + ... weight_dtype=torch.qint8, + ... bias_dtype=torch.float) + + >>> dtype_config2 = DTypeConfig( + ... input_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... output_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... weight_dtype=DTypeWithConstraints( + ... dtype=torch.qint8, + ... quant_min_lower_bound=-128, + ... quant_max_upper_bound=127, + ... ), + ... bias_dtype=torch.float) + + >>> dtype_config1.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype_with_constraints + DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \ +scale_min_lower_bound=None, scale_max_upper_bound=None) + """ + + input_dtype_with_constraints: DTypeWithConstraints + output_dtype_with_constraints: DTypeWithConstraints + weight_dtype_with_constraints: DTypeWithConstraints + bias_dtype: torch.dtype | None + is_dynamic: bool | None + + def __init__( + self, + input_dtype: torch.dtype | DTypeWithConstraints | None = None, + output_dtype: torch.dtype | DTypeWithConstraints | None = None, + weight_dtype: torch.dtype | DTypeWithConstraints | None = None, + bias_dtype: torch.dtype | None = None, + is_dynamic: bool | None = None, + ): + if isinstance(input_dtype, DTypeWithConstraints): + self.input_dtype_with_constraints = input_dtype + else: + self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype) + + if isinstance(output_dtype, DTypeWithConstraints): + self.output_dtype_with_constraints = output_dtype + else: + self.output_dtype_with_constraints = DTypeWithConstraints( + dtype=output_dtype + ) + + if isinstance(weight_dtype, DTypeWithConstraints): + self.weight_dtype_with_constraints = weight_dtype + else: + self.weight_dtype_with_constraints = DTypeWithConstraints( + dtype=weight_dtype + ) + + self.bias_dtype = bias_dtype + self.is_dynamic = is_dynamic + + @property + def input_dtype(self) -> torch.dtype | None: + return self.input_dtype_with_constraints.dtype + + @property + def output_dtype(self) -> torch.dtype | None: + return self.output_dtype_with_constraints.dtype + + @property + def weight_dtype(self) -> torch.dtype | None: + return self.weight_dtype_with_constraints.dtype + + @classmethod + def from_dict(cls, dtype_config_dict: dict[str, Any]) -> DTypeConfig: + """ + Create a ``DTypeConfig`` from a dictionary with the following items (all optional): + "input_dtype": torch.dtype or ``DTypeWithConstraints`` + "output_dtype": torch.dtype or ``DTypeWithConstraints`` + "weight_dtype": torch.dtype or ``DTypeWithConstraints`` + "bias_type": torch.dtype + "is_dynamic": bool + """ + input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY) + if input_dtype is not None and not isinstance( + input_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected input_dtype to be a torch.dtype or DTypeWithConstraints" + ) + output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY) + if output_dtype is not None and not isinstance( + output_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected output_dtype to be a torch.dtype or DTypeWithConstraints" + ) + weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY) + if weight_dtype is not None and not isinstance( + weight_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints" + ) + bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY) + is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY) + return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``DTypeConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. + """ + dtype_config_dict: dict[str, Any] = {} + if self.input_dtype is not None: + dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints + if self.output_dtype is not None: + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = ( + self.output_dtype_with_constraints + ) + if self.weight_dtype is not None: + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = ( + self.weight_dtype_with_constraints + ) + if self.bias_dtype is not None: + dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype + if self.is_dynamic is not None: + dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic + return dtype_config_dict + + +class BackendConfig: + # TODO: refer to NativeBackendConfig once that is implemented + """Config that defines the set of patterns that can be quantized on a given backend, and how reference + quantized models can be produced from these patterns. + + A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph + of the above. Each pattern supported on the target backend can be individually configured through + :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: + + (1) The supported input/output activation, weight, and bias data types + + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + + (3) (Optionally) Fusion, QAT, and reference module mappings. + + The format of the patterns is described in: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md + + Example usage:: + + import torch + from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, + ) + + weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + + def fuse_conv2d_relu(is_qat, conv, relu): + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + + # For quantizing Linear + linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + + # For fusing Conv2d + ReLU into ConvReLU2d + conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + + # For quantizing ConvReLU2d + fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) + + backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) + + """ + + def __init__(self, name: str = ""): + self.name = name + # Store all BackendPatternConfigs in a map to handle duplicates + # Note: the key in this map uses the complex reversed tuple format. + # This is intended only for internal use; users who wish to access + # the original patterns should go through `self.configs` instead. + self._pattern_complex_format_to_config: dict[Pattern, BackendPatternConfig] = {} + + def __repr__(self): + return f"BackendConfig({self.__dict__})" + + def set_name(self, name: str) -> BackendConfig: + """ + Set the name of the target backend. + """ + self.name = name + return self + + def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: + """ + Set the config for an pattern that can be run on the target backend. + This overrides any existing config for the given pattern. + """ + # Avoid circular dependencies + pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format( + config + ) # type: ignore[attr-defined] + self._pattern_complex_format_to_config[pattern_complex_format] = config + return self + + def set_backend_pattern_configs( + self, configs: list[BackendPatternConfig] + ) -> BackendConfig: + """ + Set the configs for patterns that can be run on the target backend. + This overrides any existing config for a given pattern if it was previously registered already. + """ + for conf in configs: + self.set_backend_pattern_config(conf) + return self + + @property + def configs(self) -> list[BackendPatternConfig]: + """ + Return a copy of the list of configs set in this `BackendConfig`. + """ + return list(self._pattern_complex_format_to_config.values()) + + @classmethod + def from_dict(cls, backend_config_dict: dict[str, Any]) -> BackendConfig: + """ + Create a ``BackendConfig`` from a dictionary with the following items: + + "name": the name of the target backend + + "configs": a list of dictionaries that each represents a `BackendPatternConfig` + + """ + conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) + for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): + if isinstance(d, BackendPatternConfig): + conf.set_backend_pattern_config(d) + elif isinstance(d, dict): + conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) + else: + raise ValueError( + f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary" + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. + """ + return { + NAME_DICT_KEY: self.name, + CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs], + } + + +class BackendPatternConfig: + """ + Config object that specifies quantization behavior for a given operator pattern. + For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. + """ + + def __init__(self, pattern: Pattern | None = None): + self.pattern: Pattern | None = pattern + self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + self.dtype_configs: list[DTypeConfig] = [] + self.root_module: type[torch.nn.Module] | None = None + self.qat_module: type[torch.nn.Module] | None = None + self.reference_quantized_module: type[torch.nn.Module] | None = None + self.fused_module: type[torch.nn.Module] | None = None + self.fuser_method: Callable | None = None + + # Temporary/internal configs + self._root_node_getter: Callable | None = None + self._extra_inputs_getter: Callable | None = None + self._num_tensor_args_to_observation_type: dict[int, ObservationType] = {} + self._input_type_to_index: dict[str, int] = {} + self._pattern_complex_format: Pattern | None = None + + def __repr__(self): + dict_nonempty = { + k: v + for k, v in self.__dict__.items() + if ( + (not isinstance(v, (list, dict)) and v is not None) + or (isinstance(v, (list, dict)) and len(v) > 0) + ) + } + return f"BackendPatternConfig({dict_nonempty})" + + def set_pattern(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure. + + The pattern can be a float module, functional operator, pytorch operator, or a tuple + combination of the above. Tuple patterns are treated as sequential patterns, and + currently only tuples of 2 or 3 elements are supported. + """ + if self._pattern_complex_format is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self.pattern = pattern + return self + + def set_observation_type( + self, observation_type: ObservationType + ) -> BackendPatternConfig: + """ + Set how observers should be inserted in the graph for this pattern. + + Observation type here refers to how observers (or quant-dequant ops) will be placed + in the graph. This is used to produce the desired reference patterns understood by + the backend. Weighted ops such as linear and conv require different observers + (or quantization parameters passed to quantize ops in the reference model) for the + input and the output. + + There are two observation types: + + `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance + will be different from the input. This is the most common observation type. + + `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the + same as the input. This is useful for operators like `cat`. + + Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs + with observers (and fake quantizes) attached instead of observers themselves. + """ + self.observation_type = observation_type + return self + + def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: + """ + Add a set of supported data types passed as arguments to quantize ops in the + reference model spec. + """ + self.dtype_configs.append(dtype_config) + return self + + def set_dtype_configs( + self, dtype_configs: list[DTypeConfig] + ) -> BackendPatternConfig: + """ + Set the supported data types passed as arguments to quantize ops in the + reference model spec, overriding all previously registered data types. + """ + self.dtype_configs = dtype_configs + return self + + def set_root_module( + self, root_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the root for this pattern. + + When we construct the reference quantized model during the convert phase, + the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU) + will be swapped to the corresponding reference quantized modules (e.g. + torch.ao.nn.reference.quantized.Linear). This allows custom backends to + specify custom reference quantized module implementations to match the + numerics of their lowered operators. Since this is a one-to-one mapping, + both the root module and the reference quantized module must be specified + in the same BackendPatternConfig in order for the conversion to take place. + """ + self.root_module = root_module + return self + + def set_qat_module(self, qat_module: type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the QAT implementation for this pattern. + """ + self.qat_module = qat_module + return self + + def set_reference_quantized_module( + self, reference_quantized_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the reference quantized implementation for + this pattern's root module. + + For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`. + """ + self.reference_quantized_module = reference_quantized_module + return self + + def set_fused_module( + self, fused_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the fused implementation for this pattern. + """ + self.fused_module = fused_module + return self + + def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: + """ + Set the function that specifies how to fuse this BackendPatternConfig's pattern. + + The first argument of this function should be `is_qat`, and the rest of the arguments + should be the items in the tuple pattern. The return value of this function should be + the resulting fused module. + + For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be: + + def fuse_linear_relu(is_qat, linear, relu): + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) + + For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6. + """ + self.fuser_method = fuser_method + return self + + def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: + self._root_node_getter = root_node_getter + return self + + def _set_extra_inputs_getter( + self, extra_inputs_getter: Callable + ) -> BackendPatternConfig: + self._extra_inputs_getter = extra_inputs_getter + return self + + def _set_num_tensor_args_to_observation_type( + self, num_tensor_args_to_observation_type: dict[int, ObservationType] + ) -> BackendPatternConfig: + self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type + return self + + def _set_input_type_to_index( + self, input_type_to_index: dict[str, int] + ) -> BackendPatternConfig: + self._input_type_to_index = input_type_to_index + return self + + def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure, using the reversed nested tuple format. + + See the BackendConfig README for more detail: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification + """ + if self.pattern is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self._pattern_complex_format = pattern + return self + + @classmethod + def from_dict( + cls, backend_pattern_config_dict: dict[str, Any] + ) -> BackendPatternConfig: + """ + Create a ``BackendPatternConfig`` from a dictionary with the following items: + + "pattern": the pattern being configured + "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s + "root_module": a :class:`torch.nn.Module` that represents the root for this pattern + "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern + "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized + implementation for this pattern's root module. + "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern + "fuser_method": a function that specifies how to fuse the pattern for this pattern + "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated) + + """ + + def _get_dtype_config(obj: Any) -> DTypeConfig: + """ + Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. + """ + if isinstance(obj, DTypeConfig): + return obj + if isinstance(obj, dict): + return DTypeConfig.from_dict(obj) + raise ValueError( + f"Expected a list of DTypeConfigs in " + f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'" + ) + + conf = cls() + if PATTERN_DICT_KEY in backend_pattern_config_dict: + conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY]) + if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: + conf.set_observation_type( + backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY] + ) + for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): + conf.add_dtype_config(_get_dtype_config(d)) + conf.set_root_module( + backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY)) # type: ignore[arg-type] + conf.set_reference_quantized_module( + backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_fused_module( + backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_fuser_method( + backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_root_node_getter( + backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_extra_inputs_getter( + backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_num_tensor_args_to_observation_type( + backend_pattern_config_dict.get( + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {} + ) + ) + conf._set_input_type_to_index( + backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}) + ) + if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict: + conf._set_pattern_complex_format( + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendPatternConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. + """ + backend_pattern_config_dict: dict[str, Any] = { + OBSERVATION_TYPE_DICT_KEY: self.observation_type, + DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], + } + if self.pattern is not None: + backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern + if self.root_module is not None: + backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module + if self.qat_module is not None: + backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module + if self.reference_quantized_module is not None: + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = ( + self.reference_quantized_module + ) + if self.fused_module is not None: + backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module + if self.fuser_method is not None: + backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method + if self._root_node_getter is not None: + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = ( + self._root_node_getter + ) + if self._extra_inputs_getter is not None: + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = ( + self._extra_inputs_getter + ) + if len(self._num_tensor_args_to_observation_type) > 0: + backend_pattern_config_dict[ + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY + ] = self._num_tensor_args_to_observation_type + if len(self._input_type_to_index) > 0: + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = ( + self._input_type_to_index + ) + if self._pattern_complex_format is not None: + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = ( + self._pattern_complex_format + ) + return backend_pattern_config_dict diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9b16492821b73dba1ff3ce6e2617d844d94229 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py @@ -0,0 +1,498 @@ +# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime +# not a specific backend + +import operator + +import torch +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, +) + +from ._common_operator_config_utils import _Conv2dMetadata +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .qnnpack import ( + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_qint8_symmetric_dtype_config, +) + + +__all__ = [ + "get_executorch_backend_config", +] + + +# =================== +# | DTYPE CONFIGS | +# =================== + +executorch_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +executorch_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +executorch_default_dynamic_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +executorch_default_dynamic_qint8_dtype_config = DTypeConfig( + input_dtype=executorch_act_qint8_scale_min_2_neg_12, + output_dtype=torch.float, + weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + + +# ============================= +# | BACKEND PATTERN CONFIGS | +# ============================= + + +def _get_linear_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + executorch_default_dynamic_quint8_dtype_config, + executorch_default_dynamic_qint8_dtype_config, + executorch_default_dynamic_float16_dtype_config, + ] + linear_configs: list[BackendPatternConfig] = [] + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return linear_configs + + +def _get_conv_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to conv modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + conv_configs = [] + for convs in [_Conv2dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------------------------- + # conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # fused conv relu module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # conv + batchnorm (+ relu) + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + return conv_configs + + +def _get_binary_ops_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to binary ops. + """ + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + binary_op_configs: list[BackendPatternConfig] = [] + for op in [ + operator.add, + torch.add, + operator.sub, + torch.sub, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op, torch.nn.ReLU), + (op, torch.nn.functional.relu), + (op, torch.relu), + op, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + return binary_op_configs + + +def _get_share_qparams_ops_configs() -> list[BackendPatternConfig]: + """ + Return the operator configs for the operators that works for both float and quantized + input if input is quantized, the output Tensor shares the same quantization parameter + with input. + + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + share_qparams_ops = [ + torch.nn.Flatten, + F.adaptive_avg_pool2d, + F.elu, + F.hardtanh, + F.max_pool2d, + F.pad, + F.relu, + F.relu6, + F.leaky_relu, + F.leaky_relu_, + torch.nn.AdaptiveAvgPool2d, + torch.nn.ConstantPad2d, + torch.nn.ELU, + torch.nn.MaxPool2d, + torch.nn.ReLU6, + torch.nn.Hardtanh, + torch.nn.LeakyReLU, + torch.clamp, + torch.flatten, + torch.mean, + torch.permute, + torch.permute_copy, + torch.squeeze, + "clamp", + "mean", + "permute", + "reshape", + "relu", + "relu_", + "squeeze", + "squeeze_", + "leaky_relu", + ] + share_qparams_op_configs: list[BackendPatternConfig] = [ + BackendPatternConfig(op) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in share_qparams_ops + ] + return share_qparams_op_configs + + +def _get_bn_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to batchnorm. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + bn_configs = [] + bn_configs.append( + BackendPatternConfig(nn.BatchNorm2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_cat_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concatenate) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + return cat_configs + + +def _get_embedding_op_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + executorch_weight_only_quint8_dtype_config, + ] + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for functional embedding + embedding_op_configs.append( + BackendPatternConfig(torch.nn.functional.embedding) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return embedding_op_configs + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_executorch_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack. + """ + return ( + BackendConfig("executorch") + .set_backend_pattern_configs(_get_linear_configs()) + .set_backend_pattern_configs(_get_conv_configs()) + .set_backend_pattern_configs(_get_binary_ops_configs()) + .set_backend_pattern_configs(_get_share_qparams_ops_configs()) + .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_cat_configs()) + .set_backend_pattern_configs(_get_embedding_op_configs()) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5d665f4fd030aba47c98ee692f0d9e7eca41cbc6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py @@ -0,0 +1,129 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_fbgemm_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py +# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max, +# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized +# values to within [0, 127]. + +fbgemm_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +fbgemm_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +fbgemm_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +fbgemm_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_fbgemm_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native FBGEMM backend. + """ + conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + fbgemm_weighted_op_quint8_dtype_config, + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + fbgemm_weight_only_quint8_dtype_config, + fbgemm_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("fbgemm") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py new file mode 100644 index 0000000000000000000000000000000000000000..a98d1a9a3d41b43b1c0ce55a2471d3342af71a55 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_test_only_legacy_native_backend_config", + "default_op_quint8_dtype_config", + "default_op_fp16_dtype_config", + "default_dynamic_int8_dtype_config", + "default_dynamic_float16_dtype_config", + "input_output_only_quint8_dtype_config", + "weight_only_quint8_dtype_config", + "weight_only_quint4x2_dtype_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_test_only_legacy_native_backend_config_dict", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights +input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_test_only_legacy_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops. + """ + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + default_op_fp16_dtype_config, + ] + binary_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + share_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + tensor_info_op_dtype_configs = [ + default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("_native_and_fp16") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + """ + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [default_op_quint8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("native") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form. + """ + return get_native_backend_config().to_dict() + + +def get_test_only_legacy_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional + fp16 ops in dictionary form. + """ + return get_test_only_legacy_native_backend_config().to_dict() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc7a2cf4c669742583fb1fe23d9948e04dbecc1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py @@ -0,0 +1,641 @@ +# mypy: allow-untyped-defs +import itertools +import operator + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2 +from torch.ao.quantization.utils import MatchAllNode + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +# =================== +# | DTYPE CONFIGS | +# =================== + +onednn_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +onednn_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +onednn_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +onednn_weight_only_qint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.qint8, +) + +onednn_input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +# =================== +# | FUSER METHODS | +# =================== + + +def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): + r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + leaky_relu: LeakyReLU instance that needs to be fused with the linear layer + Examples:: + >>> # xdoctest: +SKIP(failing) + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> lr = nn.LeakyReLU(0.01) + >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) + """ + if linear.training != bn.training or bn.training != leaky_relu.training: + raise AssertionError( + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + ) + + if is_qat: + raise NotImplementedError( + f"Cannot fuse train modules: {(linear, bn, leaky_relu)}" + ) + else: + map_to_fused_module_eval = { + nn.Linear: nni.LinearLeakyReLU, + } + fused_module = map_to_fused_module_eval.get(type(linear), None) + if fused_module is not None: + fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + fm = fused_module(fused_linear, leaky_relu) + return fm + else: + raise NotImplementedError( + f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}" + ) + + +# ====================== +# | CONFIGS FOR CONV | +# ====================== +observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + +conv_dtype_configs = [onednn_weighted_op_int8_dtype_config] +conv_configs = _get_conv_configs(conv_dtype_configs) + +# (1) Conv2d + Add + +# conv2d Y +# \ / +# add + +# include: +# conv2d conv2d +# \ / +# add + + +def _fuse_conv_add_left(is_qat, add, conv, _): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_left(pattern): + _, conv, _ = pattern + return conv + + +def _conv_add_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _conv, extra_input = pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add + + +def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_left(add_pattern): + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_left(add_pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_left) + ._set_root_node_getter(_conv_bn_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_left) + ._set_root_node_getter(_conv_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + +# Y conv2d +# \ / +# add + + +def _fuse_conv_add_right(is_qat, add, _, conv): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_right(pattern): + _add, _, conv = pattern + return conv + + +def _conv_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _conv = pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add + + +def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_right(pattern): + _add, _, bn_conv = pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _bn_conv = pattern + return [extra_input] + + +conv_add_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_right) + ._set_root_node_getter(_conv_bn_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_right) + ._set_root_node_getter(_conv_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + +conv_configs.append( + BackendPatternConfig(nni.ConvAdd2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# (2) Conv2d + Add + Relu + +# conv2d Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_left(is_qat, relu, add_pattern): + add, conv, _ = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, conv, _ = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _conv, extra_input = add_pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern): + add, bn_conv, _ = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_left) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_left) + ._set_root_node_getter(_conv_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + +# Y conv2d +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_right(is_qat, relu, add_pattern): + add, _, conv = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _extra_input, conv = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _conv = add_pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern): + add, _, bn_conv = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _, bn_conv = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _bn_conv = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_right) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_right) + ._set_root_node_getter(_conv_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + +conv_configs.append( + BackendPatternConfig(nni.ConvAddReLU2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# ======================== +# | CONFIGS FOR LINEAR | +# ======================== + +linear_dtype_configs = [ + onednn_weighted_op_int8_dtype_config, + onednn_dynamic_int8_dtype_config, +] +linear_configs = _get_linear_configs(linear_dtype_configs) + + +def _add_eltwise_fusion_configs( + configs, + root_module, + root_op, + post_module, + post_op, + dtype_configs, + fuser_method, + fused_module, + observation_type, + ref_quant_module, +): + # 1 base module + op module fusion config + configs.append( + BackendPatternConfig((root_module, post_module)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + # base module + functional post op + configs.append( + BackendPatternConfig((root_module, post_op)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + + # 2 fused module configs + configs.append( + BackendPatternConfig(fused_module) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(root_module) + .set_reference_quantized_module(ref_quant_module) + ) + + # 3 functional base op + post op configs + configs.append( + BackendPatternConfig((root_op, post_module)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + configs.append( + BackendPatternConfig((root_op, post_op)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + +# Configs for linear + leaky_relu fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.LeakyReLU, + F.leaky_relu, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearLeakyReLU), + nni.LinearLeakyReLU, + observation_type, + nnqr.Linear, +) + +# Configs for linear module + batchnorm + leaky_relu +linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU)) + .set_dtype_configs(linear_dtype_configs) # noqa: E131 + .set_fuser_method(_fuse_linear_bn_leaky_relu) + .set_fused_module(nni.LinearLeakyReLU) +) + +# Configs for linear + tanh fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.Tanh, + torch.tanh, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearTanh), + nni.LinearTanh, + observation_type, + nnqr.Linear, +) + +# =========================== +# | CONFIGS FOR OTHER OPS | +# =========================== + +binary_op_dtype_configs = [onednn_op_quint8_dtype_config] +default_op_dtype_configs = [onednn_op_quint8_dtype_config] +fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config] +embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config] +layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config] + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_onednn_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native ONEDNN backend. + """ + return ( + BackendConfig("onednn") + .set_backend_pattern_configs(conv_configs) + .set_backend_pattern_configs(linear_configs) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +__all__ = [ + "get_onednn_backend_config", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..841bac512a6549f39f757b9531591f1e47e72a83 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py @@ -0,0 +1,171 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints + + +__all__ = [ + "get_qnnpack_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +qnnpack_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +qnnpack_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +qnnpack_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +qnnpack_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + +# xnnpack compatible dtype configs + +# We restrict scale values to be 2 ** -12 to ensure the +# requantization scale never falls below the xnnpack lower +# threshold. Additionally, for qint8 weight, we restrict +# the quantization values to [-127, +127], excluding -128. +# For more detail, refer to the description of +# `default_symmetric_qnnpack_qconfig`. + +# TODO: add additional restriction on qscheme to ensure it +# is either per_tensor_symmetric or per_channel_symmetric + +qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, +) + +qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_qnnpack_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native QNNPACK backend. + """ + conv_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + ] + linear_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + default_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + fixed_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + qnnpack_weight_only_quint8_dtype_config, + qnnpack_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("qnnpack") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..d0490e2071f4f2df59b4bb6eb2a1d7885b4aa036 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +__all__ = [ + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", +] + + +def get_tensorrt_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for the TensorRT backend. + NOTE: Current api will change in the future, it's just to unblock experimentation for + new backends, please don't use it right now. + TODO: add a README when it's more stable + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + ) + + addmm_config = ( + BackendPatternConfig(torch.addmm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + .add_dtype_config(weighted_op_qint8_dtype_config) + ._set_input_type_to_index( + { + "bias": 0, + "input": 1, + "weight": 2, + } + ) + ) + cat_config = ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .add_dtype_config(non_weighted_op_qint8_dtype_config) + ) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + tensor_info_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return ( + BackendConfig("tensorrt") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_config(addmm_config) + .set_backend_pattern_config(cat_config) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + ) + + +def get_tensorrt_backend_config_dict(): + """ + Return the `BackendConfig` for the TensorRT backend in dictionary form. + """ + return get_tensorrt_backend_config().to_dict() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d486a061129324a311cdf74ebb58a51bf2dd9d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3 +from torch.ao.quantization.utils import Pattern + +from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig + + +__all__ = [ + "get_pattern_to_dtype_configs", + "get_qat_module_classes", + "get_fused_module_classes", + "get_pattern_to_input_type_to_index", + "get_root_module_to_quantized_reference_module", + "get_fuser_method_mapping", + "get_module_to_qat_module", + "get_fusion_pattern_to_root_node_getter", + "get_fusion_pattern_to_extra_inputs_getter", + "remove_boolean_dispatch_from_name", + "pattern_to_human_readable", + "entry_to_pretty_str", +] + + +def get_pattern_to_dtype_configs( + backend_config: BackendConfig, +) -> dict[Pattern, list[DTypeConfig]]: + pattern_to_dtype_configs: dict[Pattern, list[DTypeConfig]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_dtype_configs[pattern] = config.dtype_configs + return pattern_to_dtype_configs + + +def get_qat_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + qat_module_classes = [ + config.qat_module + for config in backend_config.configs + if config.qat_module is not None + ] + return tuple(set(qat_module_classes)) + + +def get_fused_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + fused_module_classes = [ + config.fused_module + for config in backend_config.configs + if config.fused_module is not None + ] + return tuple(set(fused_module_classes)) + + +def get_pattern_to_input_type_to_index( + backend_config: BackendConfig, +) -> dict[Pattern, dict[str, int]]: + pattern_to_input_type_to_index: dict[Pattern, dict[str, int]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_input_type_to_index[pattern] = config._input_type_to_index + return pattern_to_input_type_to_index + + +def get_root_module_to_quantized_reference_module( + backend_config: BackendConfig, +) -> dict[type[torch.nn.Module], type[torch.nn.Module]]: + mapping: dict[type[torch.nn.Module], type[torch.nn.Module]] = {} + for config in backend_config.configs: + if ( + config.root_module is not None + and config.reference_quantized_module is not None + ): + mapping[config.root_module] = config.reference_quantized_module + return mapping + + +def get_fuser_method_mapping( + backend_config: BackendConfig, +) -> dict[Pattern, nn.Sequential | Callable]: + fuser_method_mapping: dict[Pattern, nn.Sequential | Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # Note: both the fuser method and the pattern are specified in forward order in the + # BackendConfig, but the internal pattern matching code uses the reversed nested tuple + # format, so we need to convert both to the internal format + fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config) + fuser_method_mapping[pattern] = fuser_method + return fuser_method_mapping + + +def get_module_to_qat_module( + backend_config: BackendConfig, +) -> dict[Pattern, type[torch.nn.Module]]: + module_to_qat_module: dict[Pattern, type[torch.nn.Module]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.qat_module is not None: + module_to_qat_module[pattern] = config.qat_module + return module_to_qat_module + + +def get_fusion_pattern_to_root_node_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns the root node + from the fusion pattern, e.g. the most common one is: + def get_root_node(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + This can work for all patterns whose root node is the "last node" in the pattern, + e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d)) + """ + root_node_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._root_node_getter is not None: + root_node_getter_mapping[pattern] = config._root_node_getter + return root_node_getter_mapping + + +def get_fusion_pattern_to_extra_inputs_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns extra input nodes + from the fusion pattern, in the order required by the root node. This is optional, + if not specified, we will not copy over any extra inputs for the root node. + Example: + # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) + # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra + # argument to the fused module, we can unpack the pattern and return the node at + # MatchAllNode here + # we can implement extra_inputs_getter as follows: + def extra_inputs_getter(pattern) -> List[Any]: + add, extra_input, conv_pattern = pattern + return [extra_input] + """ + extra_inputs_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._extra_inputs_getter is not None: + extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter + return extra_inputs_getter_mapping + + +def remove_boolean_dispatch_from_name(p) -> Any: + """ + Some ops have a default string representation such as + '.fn at 0x7ff1106bf280>', + this function replaces them with the hardcoded function names. + """ + if p is F.fractional_max_pool2d: + return "torch.nn.functional.fractional_max_pool2d" + elif p is F.fractional_max_pool3d: + return "torch.nn.functional.fractional_max_pool3d" + elif p is F.max_pool1d: + return "torch.nn.functional.max_pool1d" + elif p is F.max_pool2d: + return "torch.nn.functional.max_pool2d" + elif p is F.max_pool3d: + return "torch.nn.functional.max_pool3d" + elif p is F.adaptive_max_pool1d: + return "torch.nn.functional.adaptive_max_pool1d" + elif p is F.adaptive_max_pool2d: + return "torch.nn.functional.adaptive_max_pool2d" + elif p is F.adaptive_max_pool3d: + return "torch.nn.functional.adaptive_max_pool3d" + if "boolean_dispatch" in str(p): + raise AssertionError( + f"{p} does not have a human readable representation in " + + "quantization documentation" + ) + return p + + +def pattern_to_human_readable(p) -> Any: + if isinstance(p, tuple): + # nested patterns, recurse + return tuple(pattern_to_human_readable(inner_p) for inner_p in p) + elif isinstance(p, str): + # method names are already human readable + return p + else: + p = remove_boolean_dispatch_from_name(p) + return p + + +# TODO(future PR): move backend_config_dict to use dataclass and move this logic to +# the corresponding __str__ function +def entry_to_pretty_str(entry) -> str: + """ + Given a backend_config_dict entry, returns a string with the human readable + representation of it. + """ + s = "{\n" + + # always output the pattern first + if "pattern" in entry: + pattern_str = pattern_to_human_readable(entry["pattern"]) + + s += f" 'pattern': {pattern_str},\n" + + # custom output for dtype_configs to make it look nice + if "dtype_configs" in entry: + s += " 'dtype_configs': [\n" + for dtype_config in entry["dtype_configs"]: + s += " {\n" + for k, v in dtype_config.items(): + s += f" '{k}': {v},\n" + s += " },\n" + s += " ],\n" + + # custom output for num_tensor_args_to_observation_type to make it look nice + if "num_tensor_args_to_observation_type" in entry: + s += " 'num_tensor_args_to_observation_type': {\n" + for k, v in entry["num_tensor_args_to_observation_type"].items(): + s += f" {k}: {v},\n" + s += " },\n" + + # output all the other fields + custom_handled_fields = [ + "pattern", + "dtype_configs", + "num_tensor_args_to_observation_type", + ] + for field_name in entry: + if field_name in custom_handled_fields: + continue + s += f" '{field_name}': {entry[field_name]},\n" + + s += "}" + return s + + +def _get_pattern_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Pattern: + """ + Return the pattern specified in the given config in the reversed nested tuple format + used internally in the quantization pattern matching code. + + If the pattern is not a tuple, or the pattern is already specified in the reversed + nested tuple format, return the pattern as is. Otherwise: + + For 2-tuples (a, b), return (b, a). + For 3-tuples (a, b, c), return (c, (b, a)). + + For example: + * Given nn.Linear, return nn.Linear + * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear) + * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + + For context, the reason why this is needed is the user-facing BackendConfig + API accepts the flat 2-or-3-tuple format in forward order. While this simple + format handles the vast majority of use cases, it does not handle the more + complex ones, and so the internal pattern matching code for quantization uses + the following, more general reversed nested tuple format instead: + + operator = module_type | functional | torch op | native op | MatchAllNode + Pattern = (operator, Pattern, Pattern, ...) | operator + + In the future, we expect to replace the above complex format with the one used + by the subgraph rewriter in torch.fx, so we don't have to maintain our own + complex pattern matching code. Then we won't need this helper function anymore. + """ + if config._pattern_complex_format is not None: + return config._pattern_complex_format + if config.pattern is None: + raise ValueError( + "Either 'pattern' or 'pattern_complex_format' must be specified" + ) + if not isinstance(config.pattern, tuple): + return config.pattern + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + (a, b) = config.pattern + return (b, a) + elif len(config.pattern) == 3: + (a, b, c) = config.pattern + return (c, (b, a)) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) + + +def _get_fuser_method_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Callable: + """ + Return the fuser method specified in the given config in the reversed nested + tuple format used internally in the quantization pattern matching code. + + If pattern is specified in the reversed nested tuple format, we assume the + fuser method is also specified in this format and simply return it as is. + Otherwise, we convert the fuser method as follows: + + * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv) + * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv), + where bn_conv is a 2-tuple (bn, conv) + + The first argument of a fuser method is always `is_qat` and is not affected + in the conversion. We currently only support functions with 3 or 4 arguments. + """ + if config.fuser_method is None: + raise AssertionError("config.fuser_method must be provided") + if config._pattern_complex_format is not None: + return config.fuser_method + if not isinstance(config.pattern, tuple): + raise ValueError("Expected pattern to be a tuple, got: ", config.pattern) + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + return _reverse2(config.fuser_method) + elif len(config.pattern) == 3: + return _reverse3(config.fuser_method) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py new file mode 100644 index 0000000000000000000000000000000000000000..c64b56c981b391140f63038ac507b0708ee876f4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py @@ -0,0 +1,126 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_x86_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# X86 aligns with FBGEMM for now + +x86_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +x86_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +x86_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +x86_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +x86_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_x86_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native x86 backend. + """ + conv_dtype_configs = [x86_weighted_op_int8_dtype_config] + linear_dtype_configs = [ + x86_weighted_op_int8_dtype_config, + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + default_op_dtype_configs = [x86_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + x86_weight_only_quint8_dtype_config, + x86_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("x86") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72d624ad7d6a3926c5d34afab3b7066928f9933d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py @@ -0,0 +1,3 @@ +from .convert import convert +from .fuse import fuse +from .prepare import prepare diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py new file mode 100644 index 0000000000000000000000000000000000000000..0754627a19dd1241dda4c53121f994b1b63ff025 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py @@ -0,0 +1,1268 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch._refs import _unsqueeze_multiple +from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax +from torch.library import impl, Library + + +# Note: decomposed means decomposed quantized tensor, using decomposed so that the +# name is not too long +quantized_decomposed_lib = Library("quantized_decomposed", "DEF") + +_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.uint16, torch.int16, torch.int32] +_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn] + +_DTYPE_TO_QVALUE_BOUNDS = { + k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES +} +_DTYPE_TO_QVALUE_BOUNDS.update( + {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES} +) + + +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + + if quant_min < quant_min_lower_bound: + raise AssertionError( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + if quant_max > quant_max_upper_bound: + raise AssertionError( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta") +def quantize_per_tensor_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def quantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta") +def quantize_per_tensor_tensor2_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return quantize_per_tensor_tensor_meta( + input, + scale, + zero_point, # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with + quantization parameters in the argument of this function (scale/zero_point) + + scale (float): quantization parameter for affine quantization + + zero_point (int): quantization parameter for affine quantization + + quant_min (int): minimum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): dtype for input Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + # TODO: investigate why + # (input - zero_point).to(torch.float32) * scale + # failed the test + return (input.to(out_dtype) - zero_point) * scale + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta") +def dequantize_per_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, + quant_max, + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + return torch.empty_like(input, dtype=out_dtype) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor2", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta") +def dequantize_per_tensor_tensor2_meta( + input, + scale, + zero_point, + quant_min, + quant_max, + dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + return dequantize_per_tensor_tensor_meta( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype + ) + + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + ) + + +quantized_decomposed_lib.define( + "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_symmetric.tensor", + "CompositeExplicitAutograd", +) +def choose_qparams_symmetric_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + qscheme=torch.per_tensor_symmetric, + ) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta") +def choose_qparams_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if quant_min >= quant_max: + raise AssertionError( + f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}" + ) + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta") +def choose_qparams_symmetric_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +# Helper function used to implement per-channel quantization against any axis +def _permute_to_axis_zero(x, axis): + new_axis_list = list(range(x.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(tuple(new_axis_list)) + return y, new_axis_list + + +quantized_decomposed_lib.define( + "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd") +def quantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + zero_point (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + zero_points = zero_points.view(new_shape) + + res = torch.clamp( + torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max + ) + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta") +def quantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") +def dequantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with + quantization parameter in the argument of this function (scales/zero_points/axis) + + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + + zero_points (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + + quant_min (int): minimum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): requested dtype for output Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + if zero_points is not None: + res = (input - zero_points.view(new_shape)) * scales + else: + res = input * scales + + res = res.to(out_dtype) + + out = res.permute(tuple(permute_axis_list)) + return out + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta") +def dequantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + + scales = input.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + if dtype == torch.int8: + n_bits = 8 + quant_max = 2 ** (n_bits - 1) - 1 + else: + raise Exception( # noqa: TRY002 + f"unsupported dtype in choose_qparams_per_token: {dtype}" + ) + + scales = scales.clamp(min=1e-5).div(quant_max) + zero_points = torch.zeros_like(scales) + return scales, zero_points + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "Meta", +) +def choose_qparams_per_token_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +quantized_decomposed_lib.define( + "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "_choose_qparams_per_token_asymmetric_impl", + "CompositeImplicitAutograd", +) +def _choose_qparams_per_token_asymmetric_impl( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val = torch.amin(input, dim=-1, keepdim=True) + max_val = torch.amax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? + + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(torch.float64), zero_point.to(torch.int64) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token_asymmetric( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + return _choose_qparams_per_token_asymmetric_impl(input, dtype) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "Meta", +) +def choose_qparams_per_token_asymmetric_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +def _per_token_quant_qparam_dim_check(input, scales, zero_points): + num_tokens = math.prod(list(input.size())[:-1]) + if num_tokens != scales.numel(): + raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}") + if num_tokens != zero_points.numel(): + raise AssertionError( + f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd") +def quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + """Per token quantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + _per_token_quant_qparam_dim_check(input, scales, zero_points) + input = ( + input.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp(quant_min, quant_max) + .to(dtype) + ) + return input + + +@impl(quantized_decomposed_lib, "quantize_per_token", "Meta") +def quantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd") +def dequantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + """Per token dequantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): quantized Tensor (uint8, int8 etc.) + scales (float64 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int64 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + input = input - zero_points + input = input * scales + # Since scales are of float64 type, we need to cast it to output dtype requested + return input.to(output_dtype) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta") +def dequantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + # TODO: support fp16 + return torch.empty_like(input, dtype=output_dtype) + + +quantized_decomposed_lib.define( + "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size) -> Tensor" +) + + +# TODO: dtype is ignored for now +@impl( + quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd" +) +def quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") + + # TODO: check for dtype, currently we can't express torch.int4 so it's omitted + to_quant = input.reshape(-1, group_size) + if torch.isnan(to_quant).sum() != 0: + raise AssertionError("to_quant must not contain NaNs") + + scales = scales.reshape(-1, 1) + zero_points = zero_points.reshape(-1, 1) + + input_int8 = ( + to_quant.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp_(quant_min, quant_max) + .to(dtype) + .reshape_as(input) + ) + + return input_int8 + + +@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta") +def quantize_per_channel_group_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + """Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_channel_group", + "CompositeExplicitAutograd", +) +def dequantize_per_channel_group( + w_int8: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size: int = 128, + output_dtype: torch.dtype = torch.float32, +): + """Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): quantized Tensor (uint8/int8 etc.) + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column dequantize + if group_size > w_int8.shape[-1] and scales.shape[-1] == 1: + group_size = w_int8.shape[-1] + if w_int8.shape[-1] % group_size != 0: + raise AssertionError("w_int8.shape[-1] must be divisible by group_size") + if w_int8.dim() != 2: + raise AssertionError("w_int8 must be 2-dimensional") + + w_int8_grouped = w_int8.reshape(-1, group_size) + scales = scales.reshape(-1, 1) + if zero_points is not None: + zp = zero_points.reshape(-1, 1) + else: + zp = torch.zeros([], dtype=torch.int32, device=scales.device) + w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype) + return w_dq + + +quantized_decomposed_lib.define( + "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max) -> Tensor" +) + + +class FakeQuantPerChannel(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): + if scales.dtype != torch.float32: + scales = scales.to(torch.float32) + if zero_points.dtype != torch.int32: + zero_points = zero_points.to(torch.int32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim)) + unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) + unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims) + temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points + out = ( + torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points + ) * unsqueeze_scales + mask = torch.logical_and((temp >= quant_min), (temp <= quant_max)) + + ctx.save_for_backward(mask) + return out + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd") +def fake_quant_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return FakeQuantPerChannel.apply( + input, scales, zero_points, axis, quant_min, quant_max + ) + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta") +def fake_quant_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return torch.empty_like(input) + + +quantized_decomposed_lib.define( + "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "convert_element_type.no_fuse", + "CompositeExplicitAutograd", +) +def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.prims.convert_element_type.default(input, dtype) + + +@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta") +def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.empty_like(input, dtype=dtype) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..dda37214210e34bb7676b9877d2e44876366a07f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py @@ -0,0 +1,1020 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections import namedtuple +from typing import Any + +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.observer import ( + _with_args, + ObserverBase, + PerChannelMinMaxObserver, +) +from torch.ao.quantization.utils import _parent_name, check_min_max_valid +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .utils import ( + get_new_attr_name_with_prefix, + maybe_get_next_module, + node_arg_is_weight, +) + + +CUSTOM_MODULE_SUPP_LIST: list[Any] = [] + + +def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor: + """Reshapes the scale so that we can multiply it to the input by the given axis.""" + new_shape = [1] * input.ndim + new_shape[axis] = input.size(axis) + return scale.view(new_shape) + + +qsheme_mapping_per_tensor_to_per_channel = { + torch.per_tensor_affine: torch.per_channel_affine, + torch.per_tensor_symmetric: torch.per_channel_symmetric, +} + + +class _InputEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of input columns, and + computing the quantization parameters for the overall min/max input values. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + The running minimum/maximum :math:`x_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`, + with the difference that the running min/max values are stored per column. + This observer is intended to be used along with a WeightEqualizationObserver + to calculate the equalization scale. + """ + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + raise TypeError("Input qscheme must be per-tensor") + + self.dtype = dtype + self.qscheme = qscheme + + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.input_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + self.equalization_shape: list[int] = [] + + def forward(self, x_orig): + if x_orig.ndim < 2 or x_orig.ndim > 5: + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers) + self.equalization_shape = [1] * x_orig.ndim + self.equalization_shape[1] = x_orig.size(1) + + return self.input_obs(x_orig) + + def get_input_minmax(self): + return (self.input_obs.min_val, self.input_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + # Reshape the equalization scale along axis=1 so that it can be + # multiplied with the input along axis=1 + if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): + return + self.equalization_scale = torch.reshape( + equalization_scale, self.equalization_shape + ) + + def calculate_scaled_minmax(self): + r"""Returns the scaled min/max inputs""" + if ( + self.equalization_scale.nelement() == 1 + and self.equalization_scale == torch.tensor(1) + ): + warnings.warn( + "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " + + "Will not scale the next quantization observer.", + stacklevel=2, + ) + return None, None + + # Calculate qparams for the scaled min/max inputs + # Scale the input by the equalization scale located at the same column + # index + (min_inputs, max_inputs) = self.get_input_minmax() + equalization_scale_reshaped = reshape_scale( + self.equalization_scale, 0, min_inputs + ) + min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped)) + max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped)) + + return min_input_scaled, max_input_scaled + + with_args = classmethod(_with_args) + + +class _WeightEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of weight columns and + rows, and computing the quantization parameters for the weight rows. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used + to record the running minimum and maximum of columns of incoming weight + tensors. This observer is intended to be used along with an + InputEqualizationObserver to calculate the equalization scale. + + The running minimum/maximum :math:`w_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`. + """ + + def __init__( + self, + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + self.dtype = dtype + self.qscheme = qscheme + self.ch_axis = 1 + + per_channel_qscheme = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.weight_col_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + + def forward(self, w_orig): + if w_orig.ndim < 2 or w_orig.ndim > 5: + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + return self.weight_col_obs(w_orig) + + def get_weight_col_minmax(self): + return (self.weight_col_obs.min_val, self.weight_col_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + self.equalization_scale = equalization_scale + + with_args = classmethod(_with_args) + + +def calculate_equalization_scale( + input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver +) -> torch.Tensor: + r"""Calculates the equalization scale and sets the equalization_scale value + in the observers. + + Args: + input_obs: Observer that tracks the ranges for the input columns + weight_obs: Observer that tracks the ranges for the weight columns + """ + + (min_inputs, max_inputs) = input_obs.get_input_minmax() + (min_weights, max_weights) = weight_obs.get_weight_col_minmax() + + if not ( + check_min_max_valid(min_inputs, max_inputs) + and check_min_max_valid(min_weights, max_weights) + ): + warnings.warn( + "Must run observer before calling calculate_equalization_scale. " + + "Returning default equalization scale torch.tensor(1).", + stacklevel=2, + ) + return torch.tensor(1) + + if min_inputs.shape != min_weights.shape: + raise ValueError( + "Input and Weight must have the same column dimension. " + + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." + ) + + equalization_scale = torch.sqrt( + (max_weights - min_weights) / (max_inputs - min_inputs) + ) + # Replace all 'inf', 'nan', 0's with 1s to prevent errors + equalization_scale[equalization_scale == 0.0] = 1 + equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1) + return equalization_scale + + +class EqualizationQConfig( + # pyrefly: ignore [invalid-inheritance] + namedtuple("EqualizationQConfig", ["input_activation", "weight"]) +): + """ + Describes how to quantize a layer or a part of the network specifically for + input-weight equalization by providing settings (observer classes) for + inputs, outputs, and weights. + + Note that EqualizationQConfig needs to contain observer **classes** (like + MinMaxObserver) or a callable that returns instances on invocation, not the + concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of + the layers. + + Observer classes have usually reasonable default arguments, but they can be + overwritten with `with_args` method (that behaves like functools.partial): + + my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8), + weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): + if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "EqualizationQConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + self = super().__new__(cls, input_activation, weight) + return self + + +input_equalization_observer = _InputEqualizationObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_tensor_symmetric +) +weight_equalization_observer = _WeightEqualizationObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +default_equalization_qconfig = EqualizationQConfig( + input_activation=input_equalization_observer, weight=weight_equalization_observer +) + + +def fused_module_supports_equalization(module) -> bool: + """Checks if the fused node supports equalization.""" + return type(module) in [ + nni.LinearReLU, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + ] + + +def nn_module_supports_equalization(module) -> bool: + """Checks if the torch.nn node supports equalization.""" + return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d] + + +def custom_module_supports_equalization(module) -> bool: + """Checks if the custom node supports equalization.""" + return type(module) in CUSTOM_MODULE_SUPP_LIST + + +def node_supports_equalization(node: Node, modules) -> bool: + """Checks if the current node supports equalization + Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers + """ + if node.op == "call_module": + return ( + nn_module_supports_equalization(modules[str(node.target)]) + or fused_module_supports_equalization(modules[str(node.target)]) + or custom_module_supports_equalization(modules[str(node.target)]) + ) + elif node.op == "call_function": + return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d] + return False + + +def is_equalization_observer(observer: nn.Module) -> bool: + return isinstance( + observer, (_InputEqualizationObserver, _WeightEqualizationObserver) + ) + + +############################################################################### +# Functions for equalization during convert # +############################################################################### + + +def get_op_node_and_weight_eq_obs( + input_eq_obs_node: Node, model: GraphModule, modules: dict[str, nn.Module] +) -> tuple[Node | None, _WeightEqualizationObserver | None]: + """Gets the following weight equalization observer. There should always + exist a weight equalization observer after an input equalization observer. + + Returns the operation node that follows the input equalization observer node + and the weight equalization observer + """ + + # Find the op node that comes directly after the input equalization observer + op_node = None + for user in input_eq_obs_node.users: + if node_supports_equalization(user, modules): + op_node = user + break + + if op_node is None: + raise AssertionError( + "Expected an operation node after the input equalization observer" + ) + if op_node.op == "call_module": + # If the op_node is a nn.Linear layer, then it must have a + # WeightEqualizationObserver configuration + maybe_equalization_node_name_to_config = _get_observed_graph_module_attr( + model, "equalization_node_name_to_qconfig" + ) + if maybe_equalization_node_name_to_config is None: + raise AssertionError( + "Expected 'equalization_node_name_to_qconfig' attribute in observed graph module" + ) + equalization_node_name_to_qconfig: dict[str, Any] = ( + maybe_equalization_node_name_to_config # type: ignore[assignment] + ) + if equalization_node_name_to_qconfig.get(op_node.name, None) is None: + raise AssertionError( + f"No equalization qconfig found for op node {op_node.name}" + ) + weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr] + op_node.name, None + ).weight() + + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + return op_node, weight_eq_obs + + elif op_node.op == "call_function": + weight_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_node is not None: + weight_eq_obs = modules[str(weight_node.target)] + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + return op_node, weight_eq_obs + + return None, None + + +def maybe_get_weight_eq_obs_node( + op_node: Node, modules: dict[str, nn.Module] +) -> Node | None: + """Gets the weight equalization observer node if it exists.""" + if op_node.op != "call_function": + raise AssertionError( + "maybe_get_weight_eq_obs_node expects a call_function op_node" + ) + for node_arg in op_node.args: + if node_arg_is_weight(op_node, node_arg): + if ( + isinstance(node_arg, Node) + and node_arg.op == "call_module" + and isinstance( + modules[str(node_arg.target)], _WeightEqualizationObserver + ) + ): + return node_arg + return None + + +def maybe_get_next_input_eq_obs( + node: Node, modules: dict[str, nn.Module] +) -> _InputEqualizationObserver | None: + """Gets the following input equalization observer if it exists. + + For example, in the case of connecting linear layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + If the node being passed in is the linear1 node, then we want to return eq_obs2, + the following equalization observer for linear2. + + However, if there are no connecting layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add + Then we want to return None. + + In the case of an unfused linear-relu layer with a connecting linear layer: + linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + Since it is unfused, we want to skip over the relu layer and return eq_obs2, + the following equalization observer for linear2. + """ + + if not node_supports_equalization(node, modules): + raise AssertionError("Node does not support equalization") + + # Locate the following nn.ReLU or F.relu node if it exists + maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) + if maybe_relu_node is None: + maybe_relu_node = maybe_get_next_module( + node, modules, target_functional_type=F.relu + ) + + # Locate the following output observer if it exists. + # We will skip the relu node if it exists. + maybe_obs_node = ( + maybe_get_next_module(node, modules, ObserverBase) + if maybe_relu_node is None + else maybe_get_next_module(maybe_relu_node, modules, ObserverBase) + ) + if maybe_obs_node is None: + return None + + maybe_eq_obs_node = maybe_get_next_module( + maybe_obs_node, modules, _InputEqualizationObserver + ) + if maybe_eq_obs_node is None: + return None + + maybe_eq_obs = modules[str(maybe_eq_obs_node)] + if not isinstance(maybe_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the following equalization observer to be an _InputEqualizationObserver" + ) + return maybe_eq_obs + + +def maybe_get_next_equalization_scale( + node: Node, modules: dict[str, nn.Module] +) -> torch.Tensor | None: + """If the next next node is an InputEqualizationObserver then we want to + return its equalization scale, else we return 1 + + This is used in the case where there are two connecting linear layers: + linear1 -> LinearOutObs -> InputEqObs -> linear2 + In this case, the node given is linear1 and we want to locate the InputEqObs. + """ + next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) + # pyrefly: ignore [invalid-argument] + if next_inp_eq_obs: + if ( + next_inp_eq_obs.equalization_scale.nelement() == 1 + and next_inp_eq_obs.equalization_scale == torch.tensor(1) + ): + return None + return next_inp_eq_obs.equalization_scale + return None + + +def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None: + """Scales the following input quantization observer's min/max values by + updating the values with the scaled min/max values calculated by the input + equalization observer + """ + input_eq_obs = modules[str(node.target)] + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the module at node.target to be an _InputEqualizationObserver" + ) + + input_quant_obs_node = node.args[0] + if not isinstance(input_quant_obs_node, Node): + raise AssertionError( + "Expected the input quantization observer node to be a Node" + ) + + input_quant_obs = modules[str(input_quant_obs_node.target)] + if not isinstance(input_quant_obs, ObserverBase): + return + + min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax() + if min_input_scaled is None and max_input_scaled is None: + return + input_quant_obs.min_val = min_input_scaled + input_quant_obs.max_val = max_input_scaled + + +def scale_weight_node( + node: Node, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: torch.Tensor | None, +) -> None: + """Scale the weights for input-weight equalization by multiplying the + weight by 1/equalization_scale and next_equalization_scale + + Args: + node: Current node whose weights we want to scale + equalization_scale: Current node's calculated equalization scale + next_equalization_scale: Next node's calculated equalization scale if + the following node needs to be equalized, 1 otherwise + """ + if equalization_scale is None: + return + + if fused_module_supports_equalization(modules[str(node.target)]): + op_module = modules[str(node.target)][0] # type: ignore[index] + else: + op_module = modules[str(node.target)] + if not ( + nn_module_supports_equalization(op_module) + or custom_module_supports_equalization(op_module) + ): + raise AssertionError( + "Expected operation module to support equalization (nn or custom)" + ) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + weight = op_module.weight + if not isinstance(weight, torch.Tensor): + raise AssertionError("Expected op_module.weight to be a torch.Tensor") + + # Scale the weights by the reciprocal of the equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + op_module.weight = nn.Parameter(scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=0 + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + op_module.weight = nn.Parameter(scaled_weight) + + # Multiply the bias element wise by the next equalization scale + bias = op_module.bias + if bias is None: + return + if not isinstance(bias, torch.Tensor): + raise AssertionError("Expected op_module.bias to be a torch.Tensor") + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + op_module.bias = nn.Parameter(scaled_bias) + + +def scale_weight_functional( + op_node: Node, + model: GraphModule, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: torch.Tensor | None, +) -> None: + """Scales the weight value for functional layers""" + if equalization_scale is None: + return + + # From the given op_node, the path looks like: + # get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node + # So we want to trace back from the op_node to get the equalization observer + # node, then the quantization observer node, and then finally the weight + # node which contains the weight values. + + # Get the equalization observer node + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + # Get the quantization observer node + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + if not ( + isinstance(weight_quant_obs_node, Node) + and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) + ): + raise AssertionError( + "Expected weight_quant_obs_node to be a Node whose module is an ObserverBase" + ) + + # Get the get_attr(weight) node + weight_node = weight_quant_obs_node.args[0] + if weight_node is None: + return + if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"): + raise AssertionError("Expected weight node to be a 'get_attr' Node") + + weight_parent_name, weight_name = _parent_name(weight_node.target) + weight = getattr(modules[weight_parent_name], weight_name) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + setattr(modules[weight_parent_name], weight_name, scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + next_equalization_scale_reshaped = reshape_scale( + next_equalization_scale, 0, scaled_weight + ) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + setattr(modules[weight_parent_name], weight_name, scaled_weight) + if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight): + raise AssertionError("Model buffer for weight does not match the scaled weight") + + # Multiply the bias element wise by the next equalization scale + bias_node = None + for node in op_node.args: + # Find the node containing the weight values + if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name: + bias_node = node + break + if bias_node is None: + return + + bias_parent_name, bias_name = _parent_name(bias_node.target) + bias = getattr(modules[bias_parent_name], bias_name) + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + setattr(modules[bias_parent_name], bias_name, scaled_bias) + + +def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) -> None: + """Given the operation node, we want find the corresponding quantization + observer and reset its min/max values + """ + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + if not isinstance(weight_quant_obs_node, Node): + raise AssertionError("Expected weight_quant_obs_node to be a Node") + + weight_quant_obs = modules[str(weight_quant_obs_node.target)] + if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase): + raise AssertionError( + "Expected the module at weight_quant_obs_node to be an ObserverBase" + ) + weight_quant_obs.reset_min_max_vals() # type: ignore[operator] + + +def remove_node(model: GraphModule, node: Node, prev_node: Node): + """Removes the given node from the model by replacing all of its users with + the given previous node + """ + # For all of the current node's users, replace the current node with + # the input quantization observer node + orig_users = list(node.users.keys()) + for user_node in orig_users: + user_node.replace_input_with(node, prev_node) + + # Erase the InputEqualizationObserver node + model.graph.erase_node(node) + + +def update_obs_for_equalization( + model: GraphModule, modules: dict[str, nn.Module] +) -> dict[str, _WeightEqualizationObserver]: + """Update all of the observer's equalization scale. For each + InputEqualizationObserver, we will find the location of the next + WeightEqualizationObserver, create it, and calculate the equalization scale + based on the two observers. + + We will then return a dictionary mapping operation node names to + the corresponding WeightEqualizationObservers for that operation. + """ + weight_eq_obs_dict = {} + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + input_eq_obs = modules[node.target] + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected module at node.target to be an _InputEqualizationObserver" + ) + op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) + + if op_node is None or weight_eq_obs is None: + continue + + if op_node.op == "call_module": + # Calibrate the weight equalization observer since it has just + # been created + if fused_module_supports_equalization(modules[str(op_node.target)]): + module = modules[str(op_node.target)][0] # type: ignore[index] + if not nn_module_supports_equalization(module): + raise AssertionError( + "Expected fused module to support equalization" + ) + weight_eq_obs(module.weight) + else: + weight_eq_obs(modules[str(op_node.target)].weight) + + # Calculate and set the equalization scale values + equalization_scale = calculate_equalization_scale( + input_eq_obs, weight_eq_obs + ) + input_eq_obs.set_equalization_scale(equalization_scale) + weight_eq_obs.set_equalization_scale(equalization_scale) + + weight_eq_obs_dict[op_node.name] = weight_eq_obs + + return weight_eq_obs_dict + + +def convert_eq_obs( + model: GraphModule, + modules: dict[str, nn.Module], + weight_eq_obs_dict: dict[str, _WeightEqualizationObserver], +) -> None: + """Converts the equalization operations and updates the other nodes in the + following way: + - Removes the input equalization observers and inserts a mul operator + along with an equalization scale node wherever applicable (we do not + want to insert a mul operator between connecting linear layers). + - Updates the input quantization observers with the scaled input min/max + values. + - Scales the weights by the current and next equalization scales. + - Removes the weight equalization observer node if it exists. + + Before (after prepare): + weight values + | + WeightQuantObs + | + WeightEqObs + | + x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs + + After this function: + scaled weight values + | + equalization scale WeightQuantObs + | | + x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs + + After convert: + equalization scale scaled weight values + | | + x -> mul -> quantize_per_tensor -> quantized::linear + + Note that although the equalization observer appeared after the quantization + observer after prepare_fx, the mul node appears before the quantization node + after convert_fx. This is because placing the equalization observer after + the quantization observer in prepare_fx would allow us to keep the invariant + that the graph before the current node inserts its observers is not + modified. + + Having the equalization observer before the quantization observer would also + cause some inconsistences between the ordering of the quantization and + equalization observers. + For example, a single linear layer would look like: + x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1 + But between two connected linear layers, it would look like: + linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2 + """ + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + inp_quant_obs_node = node.args[0] + prev_node = inp_quant_obs_node.args[0] + + # If the previous node is a layer that needs to be equalized, then + # we will remove the current node because we do not need to add any + # equalization nodes between two layers that need to be equalized + + # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2 + # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2 + if ( + node_supports_equalization(prev_node, modules) + or "relu" in prev_node.name + ): + remove_node(model, node, inp_quant_obs_node) + continue + + # Update the following input quantization observer's min/max values + scale_input_observer(node, modules) + + # Remove the InputEqualization node and add a mul operator before + # the quantization observer node that appears before the equalization node + # Before: x -> input_quant_obs -> input_eq_obs -> linear + # After: x -> mul -> input_quant_obs -> linear + + # Create a node containing the equalization scale + with model.graph.inserting_before(inp_quant_obs_node): + get_new_eq_scale_name = get_new_attr_name_with_prefix( + prev_node.name + "_equalization_scale" + ) + name = get_new_eq_scale_name(modules) + setattr(model, name, modules[node.target].equalization_scale) + eq_scale_node = model.graph.create_node("get_attr", name) + + # Create a node multiplying the input with the equalization scale + with model.graph.inserting_after(eq_scale_node): + inputs = (prev_node, eq_scale_node) + mul_node = model.graph.create_node("call_function", torch.mul, inputs) + + # Set the mul nod to be the input_quant_obs_node's input instead of + # the previous node + inp_quant_obs_node.replace_input_with(prev_node, mul_node) + remove_node(model, node, inp_quant_obs_node) + + elif weight_eq_obs_dict.get(node.name, None) is not None: + weight_eq_obs = weight_eq_obs_dict.get(node.name) + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + equalization_scale = weight_eq_obs.equalization_scale + + if ( + equalization_scale.nelement() == 1 + and equalization_scale == torch.tensor(1) + ): + equalization_scale = None # type: ignore[assignment] + maybe_next_equalization_scale = maybe_get_next_equalization_scale( + node, modules + ) + + # Scale the weight nodes + if node.op == "call_module": + scale_weight_node( + node, + modules, + # pyrefly: ignore [bad-argument-type] + equalization_scale, + maybe_next_equalization_scale, + ) + elif node.op == "call_function": + scale_weight_functional( + node, + model, + modules, + # pyrefly: ignore [bad-argument-type] + equalization_scale, + maybe_next_equalization_scale, + ) + + weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) + if weight_eq_obs_node is None: + return + if not isinstance( + modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver + ): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + + # Clear the quantization observer's min/max values so that they + # can get updated later based on the new scale values + clear_weight_quant_obs_node(node, modules) + + # Erase the weight equalization observer node + prev_node = weight_eq_obs_node.args[0] + remove_node(model, weight_eq_obs_node, prev_node) # type: ignore[arg-type] + else: + raise ValueError( + "Expected operation node to be 'call_module' or 'call_function" + + f"Instead got node {node.name} as '{node.op}'." + ) + + +def _convert_equalization_ref(model: GraphModule): + """Reference function which applies changes needed for equalization, but + does not quantize the nodes + """ + modules = dict(model.named_modules(remove_duplicate=False)) + + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + return GraphModule(model, model.graph) + + +############################################################################### +# Functions for running the equalized model on the Numeric Suite # +############################################################################### + + +def get_layer_sqnr_dict( + model_a: nn.Module, model_b: nn.Module, x: torch.Tensor +) -> dict[str, float]: + """Runs the Numeric Suite on model_a and model_b and returns a dictionary + containing the SQNR between layers in model_a and model_b. + + Note: In order to support equalized models, this function has a hacky fix in + which we do not match any torch.mul operators. This is because equalized + models contain extra mul operators to scale the input by the equalization + scale, but this edge case has not been resolved yet within the numeric suite code. + + Args: + model_a: A float model + model_b: A quantized model + x: Inputs to use during calibration + """ + import torch.ao.ns._numeric_suite_fx as ns + from torch.ao.ns.fx.mappings import get_unmatchable_types_map + + unmatchable_types_map = get_unmatchable_types_map() + unmatchable_types_map["funs_unmatchable"].add(torch.mul) + + model_a_ns, model_b_ns = ns.add_loggers( + "fp32", + model_a, + "int8", + model_b, + ns.OutputLogger, + unmatchable_types_map=unmatchable_types_map, + ) + + model_a_ns(x) + model_b_ns(x) + + activation_comparison_dict = ns.extract_logger_info( + model_a_ns, model_b_ns, ns.OutputLogger, "int8" + ) + ns.extend_logger_results_with_comparison( + activation_comparison_dict, + "fp32", + "int8", + torch.ao.ns.fx.utils.compute_sqnr, + "sqnr", + ) + + # Construct a dictionary mapping layer names to the SQNR values + layer_sqnr_dict = {} + for key in activation_comparison_dict: + layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"] + sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0] + layer_sqnr_dict[layer] = sqnr + + return layer_sqnr_dict + + +def get_equalization_qconfig_dict( + layer_sqnr_dict: dict[str, float], num_layers_to_equalize: int +) -> Any: + """Given the layer to SQNR dictionary, find the layers with the highest + quantization errors, and return an equalization_qconfig_dict + specifying to only equalize those top layers. + + Args: + layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found + when comparing an equalized model against a float model) + num_layers_to_equalize: Number of layers with the highest quantization + errors to equalize + """ + + # Sort the layer_sqnr_dictionary values and get the layers with the lowest + # SQNR values (aka highest quantization errors) + layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1)) + layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] + + # Constructs an equalization_qconfig_dict that specifies to only equalize + # the layers with the highest quantization errors + module_to_qconfig_list = [ + (item[0], default_equalization_qconfig) for item in layers_to_equalize + ] + equalization_qconfig_dict = {"module_name": module_to_qconfig_list} + return equalization_qconfig_dict diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..ad20bcc96251d8fb439e5201a2038e28e5ec675b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -0,0 +1,1413 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantization_mappings import get_quantized_operator +from torch.ao.quantization.utils import _parent_name +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .utils import ( + collect_producer_nodes, + create_node_from_old_node_preserve_meta, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_qconv_prepack_op, + graph_module_from_producer_nodes, +) + + +QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = { + torch._ops.ops.quantized.hardswish: ["inplace"], + torch._ops.ops.quantized.elu: ["inplace"], + torch._ops.ops.quantized.dropout: ["inplace"], + torch._ops.ops.quantized.instance_norm: [ + "running_mean", + "running_var", + "use_input_stats", + "momentum", + ], +} + + +def _is_node_in_list(node, modules, func_list, method_list, module_type_list): + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def is_fixed_qparams_node(node, modules): + func_list = [ + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.sigmoid, + torch.tanh, + ] + method_list = [ + "hardsigmoid", + "hardsigmoid_", + "sigmoid", + "sigmoid_", + "tanh", + "tanh_", + ] + module_type_list = [ + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, + torch.nn.Softmax, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_default_node(node, modules): + func_list = [ + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + method_list: list[Any] = [] + module_type_list = [ + nnqr.ConvTranspose1d, + nnqr.ConvTranspose2d, + nnqr.ConvTranspose3d, + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.ao.nn.intrinsic.BNReLU2d, + torch.ao.nn.intrinsic.BNReLU3d, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_copy_node(node, modules): + func_list = [ + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + operator.floordiv, + # F.channel_shuffle and torch.channel_shuffle are essentially the same thing + # so we only need to put one of them here + torch.channel_shuffle, + ] + method_list = [ + "clamp", + "mean", + "relu", + "relu_", + ] + module_type_list = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.ChannelShuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_general_tensor_shape_node(node, modules): + func_list = [ + torch.narrow, + torch.transpose, + torch.repeat_interleave, + torch.squeeze, + torch.stack, + torch.unsqueeze, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + ] + method_list = [ + "contiguous", + "detach", + "detach_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "size", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + module_type_list = [ + torch.nn.Identity, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_other_node(node, modules): + func_list = [ + torch.cat, + ] + method_list: list[Any] = [] + module_type_list: list[Any] = [] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_special_pattern_node(node, modules): + res_function, res_method, res_module = False, False, False + for checker in [ + is_fixed_qparams_node, + is_default_node, + is_copy_node, + is_general_tensor_shape_node, + is_other_node, + ]: + is_call_function, is_call_method, is_call_module = checker(node, modules) + res_function = res_function or is_call_function + res_method = res_method or is_call_method + res_module = res_module or is_call_module + return res_function, res_method, res_module + + +def is_dequantize_node(node): + return ( + isinstance(node, Node) + and node.op == "call_method" + and node.target == "dequantize" + ) + + +def is_getattr_tensor_metadata_node(node): + return ( + node.op == "call_function" + and node.target is getattr + and node.args[1] == "shape" + ) + + +def is_get_tensor_info_node(node): + return node.op == "call_method" and node.target in ["shape", "size"] + + +def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: dict[str, QConfigAny]): + """ + Return True if the op is configured with a None qconfig, False otherwise. + Note: maybe need to generalize this to also check for the dtype, and we + only lower when dtype matches, but right now fbgemm/qnnpack only support + a single dtype, so it is OK for now. + """ + return op.name in qconfig_map and qconfig_map[op.name] is None + + +# Mapping from reference module class to the replacement static quantized module class for lowering +STATIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[WeightedQuantizedModule]] = { + nnqr.Linear: nnq.Linear, + nnqr.Conv1d: nnq.Conv1d, + nnqr.Conv2d: nnq.Conv2d, + nnqr.Conv3d: nnq.Conv3d, +} + +# Mapping from reference module class to the replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Linear: nnqd.Linear, + nnqr.GRUCell: nnqd.GRUCell, + nnqr.LSTMCell: nnqd.LSTMCell, + nnqr.RNNCell: nnqd.RNNCell, + nnqr.LSTM: nnqd.LSTM, + nnqr.GRU: nnqd.GRU, +} + +# Mapping from reference module class to the replacement weight only quantized module class for lowering +# TODO: correct the namespace for these modules +WEIGHT_ONLY_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Embedding: nnq.Embedding, + nnqr.EmbeddingBag: nnq.EmbeddingBag, +} + +# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge +# _lower_static_weighted_ref_module and special_pattern_replacement +SPECIAL_PATTERN_LOWER_MODULE_MAP = { + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nnqr.ConvTranspose1d: nnq.ConvTranspose1d, + nnqr.ConvTranspose2d: nnq.ConvTranspose2d, + nnqr.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.LeakyReLU: nnq.LeakyReLU, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.Dropout: nnq.Dropout, + nn.Softmax: nnq.Softmax, + nn.PReLU: nnq.PReLU, + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU), + # TODO: LinearLeakyReLU is registered as global but it is only fused and + # lowered when ondnn's backend config is used. Maybe need to separate + # registration and lowering functions for different backends in the future. + nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU), + nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh), + nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d), + nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d), + nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d), +} + +# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP: +# The refer node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs. +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d), + nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d), +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[nn.Module]] +] = { + nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU), +} + +# Mapping from a functional to lower to a 2-tuple of +# 1) The quantized version of the op +# 2) The quantized version of the op fused with relu, if it exists, else None +STATIC_LOWER_FUNCTIONAL_MAP: dict[Callable, tuple[Callable, Callable | None]] = { + F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu), + F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu), + F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu), + F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu), + F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None), + F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None), + F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None), +} + +WEIGHT_PREPACK_OPS: set[Callable] = { + torch._ops.ops.quantized.linear_prepack, + torch._ops.ops.quantized.linear_prepack_fp16, + torch._ops.ops.quantized.conv1d_prepack, + torch._ops.ops.quantized.conv2d_prepack, + torch._ops.ops.quantized.conv3d_prepack, + torch.ops.quantized.conv_transpose1d_prepack, + torch.ops.quantized.conv_transpose2d_prepack, + torch.ops.quantized.conv_transpose3d_prepack, +} + +# Mapping from a functional to a dictionary, where the key is a 2-tuple of +# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of +# 1) The dynamically quantized version of the op +# 2) The dynamically quantized version of the op fused with relu, if it exists, else None +DYNAMIC_LOWER_FUNCTIONAL_MAP: dict[ + Callable, dict[tuple[torch.dtype, torch.dtype], tuple[Callable, Callable | None]] +] = { + F.linear: { + (torch.quint8, torch.qint8): ( + torch.ops.quantized.linear_dynamic, + torch.ops.quantized.linear_relu_dynamic, + ), + (torch.float16, torch.float16): ( + torch.ops.quantized.linear_dynamic_fp16, + torch.ops.quantized.linear_relu_dynamic_fp16, + ), + }, + # dynamic conv + relu is not available yet + F.conv1d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None), + }, + F.conv2d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None), + }, + F.conv3d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None), + }, +} + +CONV_FUNCTIONAL_OPS: set[Callable] = { + F.conv1d, + F.conv2d, + F.conv3d, +} + +CONV_TRANSPOSE_FUNCTIONAL_OPS: set[Callable] = { + F.conv_transpose1d, + F.conv_transpose2d, + F.conv_transpose3d, +} + +# TODO: add tests for lowering these ops +QBIN_OP_MAPPING: dict[Callable | str, Callable] = { + operator.add: torch.ops.quantized.add, + torch.add: torch.ops.quantized.add, + operator.mul: torch.ops.quantized.mul, + operator.matmul: torch.ops.quantized.matmul, + torch.mul: torch.ops.quantized.mul, + torch.matmul: torch.ops.quantized.matmul, +} +QBIN_RELU_OP_MAPPING: dict[Callable | str, Callable] = { + operator.add: torch.ops.quantized.add_relu, + torch.add: torch.ops.quantized.add_relu, + operator.mul: torch.ops.quantized.mul_relu, + torch.mul: torch.ops.quantized.mul_relu, +} + +ORIGINAL_WEIGHTS_LOOKUP = "original_weights_lookup" + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +def _load_packed_weight( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + +def fold_weight( + quantized_model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """ + Trace back from the weight node util we hit getattr, reconstruct the + graph module with the traced nodes and run the graph module to pack the + weight. then replace the original chain of ops with the packed weight. + """ + packed_weights = {} + # map from folded node name to the prepacked weight name + folded_nodes = {} + original_weights_lookup: dict[str, list] = {} + lookup_counter = 0 + # get packed weights + for node in quantized_model.graph.nodes: + if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS: + nodes_to_fold = collect_producer_nodes(node) + if nodes_to_fold is not None: + for node_to_fold in nodes_to_fold: + folded_nodes[node_to_fold.name] = node + + prepacking_module = graph_module_from_producer_nodes( + quantized_model, nodes_to_fold + ) + packed_weight = prepacking_module() + packed_weights[node.name] = packed_weight + if keep_original_weights: + original_weights = list(prepacking_module.state_dict().values()) + original_weights_lookup[str(lookup_counter)] = sorted( + original_weights, key=lambda x: x.numel(), reverse=True + ) + if len(original_weights_lookup[str(lookup_counter)]) == 1: + # bias is None + original_weights_lookup[str(lookup_counter)].append(None) + lookup_counter += 1 + lookup_counter = 0 + + # remove folded nodes and replace the prepacking node with getattr + folded_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in quantized_model.graph.nodes: + prepack_node = folded_nodes.get(node.name, None) + if prepack_node is node: + packed_weight = packed_weights[node.name] + # add a prepacked attribute to root + op_node = next(iter(prepack_node.users)) + module_path, _ = node_name_to_scope[op_node.name] + get_new_packed_weight_name = get_new_attr_name_with_prefix( + module_path + "_packed_weight_" + ) + packed_weight_name = get_new_packed_weight_name(quantized_model) + setattr(quantized_model, packed_weight_name, packed_weight) + # replace prepack node with a getattr node + env[node.name] = folded_graph.create_node( + "get_attr", packed_weight_name, (), {} + ) + if keep_original_weights: + key_name = ( + packed_weight_name.replace(":", "_") + .replace("/", "_") + .replace("|", "_") + .replace(" ", "") + .lower() + ) + original_weights_lookup[key_name] = original_weights_lookup[ + str(lookup_counter) + ] + del original_weights_lookup[str(lookup_counter)] + lookup_counter += 1 + elif prepack_node is not None: + # remove the fold node + continue + else: + # copy other nodes + env[node.name] = folded_graph.node_copy(node, load_arg) + + quantized_model = GraphModule(quantized_model, folded_graph) + quantized_model._register_state_dict_hook(_save_packed_weight) + quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) + + if keep_original_weights: + setattr( # noqa: B010 + quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup + ) + + return quantized_model + + +def _get_module(node: Node, modules: dict[str, nn.Module]) -> nn.Module | None: + """ + Return the `torch.nn.Module` that corresponds to the specified node's target. + If no such node exists, return None. + """ + if node.op == "call_module" and str(node.target) in modules: + return modules[str(node.target)] + else: + return None + + +def _match_static_pattern( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], + dequantize_node_arg_indices: list[int], +) -> tuple[Node, Node, Node] | tuple[None, None, None]: + """ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 3-tuple of: + 1) q_node: the quantize node, + 2) relu_node: a relu node wrapping the ref_node, and + 3) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 3-tuple of (None, None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + dequantize_node_arg_indices: A list of indices in the reference node args where dequantize + nodes may be present. An empty list means skipping the check for dequantize nodes. + """ + SKIP_LOWERING_VALUE = (None, None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") + + # Handle cases where the node is wrapped in a ReLU + if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or ( + ref_node.op == "call_module" and type(_get_module(ref_node, modules)) is nn.ReLU + ): + relu_node = ref_node + ref_node = relu_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError( + "Expected the reference node after ReLU to be a torch.fx Node" + ) + else: + relu_node = None + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + expected_op = "call_function" + match_key = ref_node.target # type: ignore[assignment] + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Match dequantize node(s). Both of the following conditions must pass: + # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node + # (2) There must be at least one dequantize node + matched_dequantize = False + for i in dequantize_node_arg_indices: + if i >= len(ref_node.args): + raise AssertionError( + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + ) + arg = ref_node.args[i] + if is_dequantize_node(arg): + matched_dequantize = True + elif isinstance(arg, Node): + return SKIP_LOWERING_VALUE + if not matched_dequantize: + return SKIP_LOWERING_VALUE + + return (q_node, relu_node, ref_node) # type: ignore[return-value] + + +def _match_static_pattern_with_two_inputs( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], +) -> tuple[Node, Node] | tuple[None, None]: + """ + (dequantize \ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 2-tuple of: + 1) q_node: the quantize node, + 2) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 2-tuple of (None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + """ + SKIP_LOWERING_VALUE = (None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") + + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + # This pass only support op of "call_module" + return SKIP_LOWERING_VALUE + + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Check ref_node has 2 input nodes, both are dq node. + if len(ref_node.args) != 2: + return SKIP_LOWERING_VALUE + for i in range(len(ref_node.args)): + arg = ref_node.args[i] + if not is_dequantize_node(arg): + return SKIP_LOWERING_VALUE + + return (q_node, ref_node) + + +def _lower_static_weighted_ref_module( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find dequantize - ref module - quantize patterns + and replace them with the quantized version of the ref module. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list( + STATIC_LOWER_FUSED_MODULE_MAP.keys() + ) + q_node, _relu_node, ref_node = _match_static_pattern( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + dequantize_node_arg_indices=[0], + ) + if q_node is None: + continue + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern" + ) + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] + continue + else: + q_class = STATIC_LOWER_MODULE_MAP[ref_class] + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + if len(ref_node.args) != 1: + raise AssertionError("Expected reference node to have exactly 1 arg") + dq_node = ref_node.args[0] + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_static_weighted_ref_module_with_two_inputs( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find patterns + dequantize dequantize + \\ // + ref module + \\ + quantize + and replace them with the quantized version of the ref module. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # (dequantize \ + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys()) + (q_node, ref_node) = _match_static_pattern_with_two_inputs( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + ) + if q_node is None: + continue + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern with two inputs" + ) + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ + ref_class + ] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] + continue + else: + continue + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + if len(ref_node.args) != 2: + raise AssertionError("Expected reference node to have exactly 2 args") + for arg in ref_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_dynamic_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns + and replace them with the dynamically quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + DYNAMIC_LOWER_MODULE_MAP.keys() + ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())): + continue + ref_node = n + dq_node = ref_node.args[0] + if dq_node.op != "call_method" or dq_node.target != "dequantize": + continue + + input_dynamic_q_node = dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + activation_dtype = input_dynamic_q_node.args[1] + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) is not inner_ref_class: + continue + else: + q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment] + # TODO: maybe define a WeightedDynamicallyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[attr-defined] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0]) + + +def _lower_weight_only_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find ref_module patterns + and replace them with the weight only quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + WEIGHT_ONLY_LOWER_MODULE_MAP.keys() + ): + continue + ref_node = n + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class) + # TODO: WeightedQuantizedModule is currently assuming static quant apis + # with output_scale, output_zero_point in from_reference, we may want to + # relax that, or rename this + # TODO: maybe define a WeightedWeightOnlyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[union-attr] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + + +def _lower_static_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their quantized versions. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize) + matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys()) + (q_node, relu_node, func_node) = _match_static_pattern( + n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1] + ) + if q_node is None: + continue + if func_node is None: + raise AssertionError( + "Expected a function node when matching static functional pattern" + ) + (_, output_scale_node, output_zp_node, _) = q_node.args + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if not isinstance(output_zp_node, Node): + raise AssertionError("Expected output_zp_node to be a Node") + if not isinstance(input_dq_node, Node): + raise AssertionError("Expected input_dq_node to be a Node") + if not isinstance(weight_dq_node, Node): + raise AssertionError("Expected weight_dq_node to be a Node") + quantized_weight = weight_dq_node.args[0] + if not isinstance(quantized_weight, Node): + raise AssertionError("Expected quantized_weight to be a Node") + if quantized_weight.op != "call_function" or quantized_weight.target not in ( + torch.quantize_per_tensor, + torch.quantize_per_channel, + ): + continue + + # Step 1: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + if func_node.target is F.linear: + weight_dtype = quantized_weight.args[-1] + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv_transpose1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv_transpose1d: + # Note prepack_args[5] is groups. + for i in [2, 3, 4, 6]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + # swap dilation and groups + # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups} + # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation} + if len(prepack_args) > 6: + prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5] + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type] + # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack) + # They are not needed for compute op (i.e., quantized::linear) + kwargs = func_node.kwargs + # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias + if func_node.target is F.linear and "bias" in kwargs: + kwargs = kwargs.copy() + kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), kwargs + ) + + # Step 2: Replace reference pattern with the corresponding quantized op + (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target] # type: ignore[index] + # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases + if q_relu_func is not None: + func_node.target = q_relu_func if relu_node is not None else q_func + else: + func_node.target = q_func + func_node.args = ( + input_dq_node.args[0], + packed_weight, + output_scale_node, + output_zp_node, + ) + # kwargs for func_node has been moved to kwargs for prepack op + func_node.kwargs = {} + q_node.replace_all_uses_with(func_node) + # Move func_node after output_zp_node in the graph + output_zp_node.append(func_node) + + # Clean up: Remove quantize node, and the relu node if it exists + model.graph.erase_node(q_node) + if relu_node is not None and q_relu_func is not None: + model.graph.erase_node(relu_node) + + +def _lower_dynamic_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their dynamically + quantized versions. + Examples: + quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic + to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16 + """ + modules = dict(model.named_modules(remove_duplicate=False)) + # we want to search in reserved order so that we can match the larger patterns first + # e.g. we want to match linear - relu before linear. + for n in reversed(model.graph.nodes): + # Step 0: Find nodes that match this pattern + # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op) + # We search for the pattern backwards, starting with the quantize node + # Quantize node args: (func, scale, zp, dtype) + func_node = n + # Handle cases where the functional op is wrapped in a ReLU + if ( + func_node.op == "call_function" + and func_node.target is F.relu + or func_node.op == "call_module" + and type(modules[str(func_node.target)]) is torch.nn.ReLU + ): + relu_node = func_node + func_node = relu_node.args[0] + else: + relu_node = None + if should_skip_lowering(func_node, qconfig_map): + continue + # Linear args: (dequantized inputs, dequantized weights[, bias]) + # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups]) + if ( + func_node.op != "call_function" + or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP + ): + continue + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if ( + input_dq_node.op != "call_method" + or input_dq_node.target != "dequantize" + or weight_dq_node.op != "call_method" + or weight_dq_node.target != "dequantize" + ): + continue + + input_dynamic_q_node = input_dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + reduce_range_node = None + (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + quantized_weight = weight_dq_node.args[0] + weight_dtype = quantized_weight.args[-1] + + # Step 1: Try to select reference pattern with the corresponding quantized op + dynamic_quant_dtype_key = (activation_dtype, weight_dtype) + if ( + dynamic_quant_dtype_key + not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target] + ): + print( + f"Didn't find dtype combination {dynamic_quant_dtype_key} during " + f"dynamic quantized op lowering for {func_node.target}" + ) + continue + (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][ + dynamic_quant_dtype_key + ] + + if q_func is None or q_relu_func is None: + print( + "Didn't find corresponding quantized function or quantized relu function " + f"for {func_node.target}, {dynamic_quant_dtype_key}" + ) + continue + + # Step 2: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + prepack_kwargs = {} + if func_node.target is F.linear: + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + kwargs = func_node.kwargs.copy() + if "bias" in kwargs: + prepack_kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + func_node.kwargs = kwargs + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(func_node): + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), prepack_kwargs + ) + + # Step 3: Replace reference pattern with the corresponding quantized op + func_node.target = q_relu_func if relu_node is not None else q_func + if is_int8: + func_node.args = (pattern_input, packed_weight, reduce_range_node) + else: + func_node.args = (pattern_input, packed_weight) + + if relu_node is not None: + relu_node.replace_all_uses_with(func_node) + + # Step 4: Remove the relu node if it exists + if relu_node is not None: + model.graph.erase_node(relu_node) + + +def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfigAny]): + binary_ops_to_lower: list[Callable] = [ + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.matmul, + ] + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + (q_node, relu_node, bop_node) = _match_static_pattern( + n, + modules, + qconfig_map, + binary_ops_to_lower, + dequantize_node_arg_indices=[0, 1], + ) + if q_node is None: + continue + if bop_node is None: + raise AssertionError( + "Expected a binary op node when matching quantized binary op pattern" + ) + (_, scale_node, zero_point_node, _) = q_node.args + + # Step 1: Remove dequant nodes + num_dq_nodes = 0 + for arg in bop_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + dn_input = dq_node.args[0] + bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type] + num_dq_nodes += 1 + if num_dq_nodes <= 0: + raise AssertionError( + "Expected at least one dequantize node in binary op args" + ) + + # Step 2: Swap binary op to quantized binary op + if bop_node.target not in QBIN_OP_MAPPING: + raise AssertionError( + f"Unsupported binary op {bop_node.target} for lowering" + ) + binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING + qbin_op = binop_to_qbinop[bop_node.target] + # prepare the args for quantized binary op + # (x, y) + qop_node_args = list(bop_node.args) + # (x, y, scale, zero_point) + # add scale and zero_point arguments for Tensor - Tensor operation + if num_dq_nodes == 2: + qop_node_args.extend([scale_node, zero_point_node]) + # insert a call to quantized binary op and remove the original binary op + with model.graph.inserting_after(q_node): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, + ("call_function", qbin_op, tuple(qop_node_args), {}), + bop_node, + ) + q_node.replace_all_uses_with(qop_node) + + # Step 3: Remove quantize node, binary op node, and relu node if any + model.graph.erase_node(q_node) + if relu_node is not None: + model.graph.erase_node(relu_node) + model.graph.erase_node(bop_node) + + +def special_pattern_replacement(model: GraphModule): + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + q_node = n + is_quantize = q_node.target is torch.quantize_per_tensor + is_to_fp16 = ( + q_node.op == "call_method" + and q_node.target == "to" + and len(q_node.args) == 2 + and q_node.args[1] == torch.float16 + ) + # Only continue when neither quantize nor to_fp16 + if not is_quantize and not is_to_fp16: + continue + ref_node = q_node.args[0] + # get output scale/zero_point/dtype from the quantize node + # ref_node, scale_node, zero_point_node, dtype = q_node.args + # TODO: add safety checks that users for the ref_node and dq_node needs to be one + is_call_function, is_call_method, is_call_module = is_fixed_qparams_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + # warnings.warn( + # "Only reference patterns are currently supported for {dtype} dtype with {op} op" + # "".format(dtype=dtypes, op=ref_node)) + continue + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + continue + + # This check includes all supported ops + is_call_function, is_call_method, is_call_module = is_special_pattern_node( + ref_node, modules + ) + if not (is_call_module or is_call_function or is_call_method): + continue + if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0: + raise AssertionError("Expected ref_node to have args or kwargs") + dq_node_or_nodes = ( + ref_node.args[0] + if len(ref_node.args) > 0 + else next(iter(ref_node.kwargs.values())) + ) + if not isinstance(dq_node_or_nodes, (Node, tuple, list)): + raise AssertionError( + "Expected dq_node_or_nodes to be a Node, tuple, or list" + ) + is_dequantize = False + if isinstance(dq_node_or_nodes, Node): + is_dequantize = ( + dq_node_or_nodes.op == "call_method" + and dq_node_or_nodes.target == "dequantize" + ) + elif isinstance(dq_node_or_nodes, (tuple, list)): + is_dequantize = all( + x.op == "call_method" and x.target == "dequantize" + for x in dq_node_or_nodes + ) + + if not is_dequantize: + continue + + # TODO: enable we have patterns that needs to swap the modules + if is_call_module: + ref_module = modules[ref_node.target] + if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize: + qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module)) + scale_node = q_node.args[1] + zero_point_node = q_node.args[2] + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + + qmodule = qmodule_cls.from_reference( # type:ignore[union-attr] + ref_module, output_scale, output_zero_point + ) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, qmodule) + + # reroute around dq node: + dq_nodes: list[Node] = [] + if isinstance(dq_node_or_nodes, Node): + dq_nodes = [dq_node_or_nodes] + elif isinstance(dq_node_or_nodes, (tuple, list)): + dq_nodes = list(dq_node_or_nodes) + + for dq_node in dq_nodes: + dn_input = dq_node.args[0] + ref_node.replace_input_with(dq_node, dn_input) + + # store q node args + qnode_qparams = list(q_node.args)[1:] + # replace uses of q node with input and remove q node + q_node_input = q_node.args[0] + q_node.replace_all_uses_with(q_node_input) + model.graph.erase_node(q_node) + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_call_function: + # pass scale/zer_point arguments from quantize_per_tensor to the default node operator + # insert an op after the zero_point node so that the scale/zero_point + # nodes are is available + qop = get_quantized_operator(ref_node.target) + args = list(ref_node.args) + kwargs = dict(ref_node.kwargs) + if qop in QOP_TO_ARG_NAMES_TO_SKIP: + args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop] + for arg in args_to_skip: + if arg in kwargs: + kwargs.pop(arg) + kwargs["output_scale"] = qnode_qparams[0] + kwargs["output_zero_point"] = qnode_qparams[1] + with model.graph.inserting_after(qnode_qparams[1]): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, ("call_function", qop, tuple(args), kwargs), ref_node + ) + ref_node.replace_all_uses_with(qop_node) + model.graph.erase_node(ref_node) + else: + # remove scale/zero_point node for quantize node + for n in qnode_qparams: + if isinstance(n, Node): + model.graph.erase_node(n) + + return model + + +def _lower_getattr_tensor_metadta_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if is_getattr_tensor_metadata_node(n): + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_get_tensor_info_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if not is_get_tensor_info_node(n): + continue + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_to_native_backend( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same + operator signature so they can be lowered with the same function + """ + _lower_static_weighted_ref_module(model, qconfig_map) + _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map) + _lower_dynamic_weighted_ref_module(model) + _lower_weight_only_weighted_ref_module(model) + _lower_static_weighted_ref_functional(model, qconfig_map) + _lower_dynamic_weighted_ref_functional(model, qconfig_map) + _lower_quantized_binary_op(model, qconfig_map) + _lower_getattr_tensor_metadta_op(model) + _lower_get_tensor_info_op(model) + special_pattern_replacement(model) + model.graph.eliminate_dead_code() + model = fold_weight(model, node_name_to_scope, keep_original_weights) + model.graph.eliminate_dead_code() + model.recompile() + model.graph.lint() + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a19a40cab908baa78fffeb89f46eedc71976736 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py @@ -0,0 +1,1323 @@ +# mypy: ignore-errors + +import copy +import operator +import warnings +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fused_module_classes, + get_pattern_to_dtype_configs, + get_qat_module_classes, + get_root_module_to_quantized_reference_module, +) +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quant_type import QuantType +from torch.ao.quantization.quantize import _remove_qconfig +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _parent_name, + activation_is_statically_quantized, + get_qparam_dict, + get_swapped_custom_module_class, + is_per_channel, + to_underlying_dtype, + weight_is_quantized, +) +from torch.fx import GraphModule +from torch.fx.graph import Argument, Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from ._equalize import convert_eq_obs, update_obs_for_equalization +from .custom_config import ConvertCustomConfig, PrepareCustomConfig +from .graph_module import _is_observed_module, _is_observed_standalone_module +from .lower_to_fbgemm import lower_to_fbgemm +from .qconfig_mapping_utils import ( + _compare_prepare_convert_qconfig_mappings, + _generate_node_name_to_qconfig, + _is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .utils import ( + _get_module, + _is_custom_module_lstm, + _is_custom_module_mha, + assert_and_get_unique_device, + collect_producer_nodes, + create_getattr_from_value, + get_custom_module_class_keys, + graph_module_from_producer_nodes, + node_arg_is_weight, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = [ + "convert", + "convert_custom_module", + "convert_standalone_module", + "convert_weighted_module", +] + +SUPPORTED_QDTYPES = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.uint8, + torch.int8, + torch.uint16, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_QSCHEME_TO_CHOOSE_QPARAMS_OP = { + torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +} + + +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], + model_device: torch.device | None = None, +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + graph = model.graph + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + if hasattr(activation_post_process, "convert"): + activation_post_process.convert(model, node) + return + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + def add_dequantize_op_kwargs(dequantize_op, input_node): + dequantize_op_kwargs = {} + if "val" in input_node.meta: + dq_out_dtype = input_node.meta["val"].dtype + if dq_out_dtype != torch.float32: + dequantize_op_kwargs = {"out_dtype": dq_out_dtype} + return dequantize_op_kwargs + + if dtype in SUPPORTED_QDTYPES and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Callable | None = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and ( + not isinstance(value_or_node, (float, int)) + ): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + if dtype_ not in [torch.uint8, torch.int8]: + raise AssertionError( + "only uint8 and int8 are supported in reference flow for dynamic quantization right now" + ) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] + eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_eps_": eps, + "_dtype_": dtype_, + } + + choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + list(qparams.values()) + choose_qparams_node = graph.create_node( + "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 0), {} + ) + zero_point_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 1), {} + ) + # we have quant_min, quant_max and dtype, all should be stored + # as literals + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype, + } + + # 3. replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if NUMERIC_DEBUG_HANDLE_KEY in node.meta: + dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + NUMERIC_DEBUG_HANDLE_KEY + ] + graph.erase_node(node) + elif dtype == torch.float16: + # Insert to_fp16 -> to_fp32 node + dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + with graph.inserting_before(node): + input_node = node.args[0] + convert_fp16_node = graph.create_node( + "call_function", dtype_convert_op, (input_node, torch.float16), {} + ) + convert_fp32_node = graph.create_node( + "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {} + ) + node.replace_all_uses_with(convert_fp32_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +def _replace_observer_with_quantize_dequantize_node( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], + model_device: torch.device | None = None, +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + """ + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) + graph = model.graph + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op: Callable | None = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_dtype_": dtype, + } + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for value in qparams.values(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for value in qparams.values(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node( + node: Node, graph: Graph +) -> None: + call_custom_module_node = node.args[0] + if not isinstance(call_custom_module_node, Node): + raise AssertionError( + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + ) + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + _insert_dequantize_node(call_custom_module_node, graph) + + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) + or is_dynamic # type: ignore[return-value] + or dtype == torch.float16 + ) + + +def _has_none_qconfig( + node: Argument, node_name_to_qconfig: dict[str, QConfigAny] +) -> bool: + """Check if a node has a qconfig of None, i.e. user requested to not quantize + the node + """ + return ( + isinstance(node, Node) + and node.name in node_name_to_qconfig + and node_name_to_qconfig[node.name] is None + ) + + +def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: + """Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. + """ + for node in observed.graph.nodes: + if node.op != "call_function": + continue + for node_arg in node.args: + # node_arg is weight + if node_arg and node_arg_is_weight(node, node_arg): + weight_observer_nodes = collect_producer_nodes(node_arg) + if weight_observer_nodes is None: + continue + weight_observer_module = graph_module_from_producer_nodes( + observed, weight_observer_nodes + ) + # run the weight observer + weight_observer_module() + + +def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: + """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, + we'll recursively remove the dequantize Node + """ + if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize": + quantize_node = arg.args[0] + # we only replace the specific use since dequantize could be used by other nodes + # as well + node.replace_input_with(arg, quantize_node) + elif isinstance(arg, (list, tuple)): + for arg_element in arg: + _maybe_recursive_remove_dequantize(arg_element, node, graph) + elif isinstance(arg, dict): + for arg_element in arg.values(): + _maybe_recursive_remove_dequantize(arg_element, node, graph) + else: + warnings.warn( + f"Unsupported node type in recursive remove dequantize: {type(arg)}", + stacklevel=2, + ) + + +def _get_module_path_and_prefix( + obs_node: Node, + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> tuple[str, str]: + """Given and observer node, get the `Scope` or the fully qualified name for + the submodule containing the observed node, also return a prefix of "_input" + when the observed node is an input of a F.linear op, and not the output of another + quantized op. + TODO: this logic is hacky, we should think about how to remove it or make it more + general + """ + observed_node = obs_node.args[0] + # an observer can be inserted for both input of the next operator or output of the previous + # operator (they can be the same) + # this flag identifies if the observer is inserted only because the observed node is + # the input of the next operator + if not isinstance(observed_node, Node): + raise AssertionError( + f"Expecting observed node to be a Node, but got {observed_node}" + ) + is_input_observer_only = ( + node_name_to_qconfig[observed_node.name] is None + if observed_node.name in node_name_to_qconfig + else None + ) + if is_input_observer_only: + # if the quantize function is at the input of op, then we find the first user of the observer_node + # to get the path. If a linear call_function is in the user list, we return the first instance + # of linear node to get the FQN. + users = list(obs_node.users) + first_linear_use_or_first_use = users[0] if users else None + linear_node = None + for n in users: + if n.op == "call_function" and n.target is torch.nn.functional.linear: + linear_node = n + break + if linear_node: + first_linear_use_or_first_use = linear_node + prefix = "_input" + else: + # if the quantize function is at the output of the op, we use the observer input node to get the path + first_linear_use_or_first_use = observed_node + prefix = "" + + if ( + first_linear_use_or_first_use + and first_linear_use_or_first_use.name in node_name_to_scope + ): + module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] + else: + # TODO: it's not used, so actually we can skip quantization + # but this requires changing return type of quantize_node + # we can fix it later if needed + module_path = "" + return module_path, prefix + + +def _insert_dequantize_node(node: Node, graph: Graph) -> None: + """Inserts dequantize node for `node` in `graph`""" + with graph.inserting_after(node): + dequantize_node = graph.call_method("dequantize", (node,)) + for user_node in dict(node.users): + if user_node is not dequantize_node: + user_node.replace_input_with(node, dequantize_node) + + +def _maybe_get_observer_for_node( + node: Node, modules: dict[str, torch.nn.Module] +) -> torch.nn.Module | None: + """ + If the node is observed, return the observer + instance. Otherwise, return None. + """ + for maybe_obs_node in node.users: + if maybe_obs_node.op == "call_module": + maybe_obs = modules[str(maybe_obs_node.target)] + if _is_activation_post_process(maybe_obs): + return maybe_obs + return None + + +def convert_standalone_module( + node: Node, + modules: dict[str, torch.nn.Module], + model: torch.fx.GraphModule, + is_reference: bool, + backend_config: BackendConfig | None, +) -> None: + """Converts a observed standalone module to a quantized standalone module by calling + the fx convert api, currently using the same `is_reference` flag as parent, but we may + changing this behavior in the future (e.g. separating quantization and lowering for + standalone module as well) + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - model: original model + - is_reference: a flag from parent provided by user to decide if we want to + produce a reference model or a fbgemm/qnnpack model + - backend_config: backend configuration of the target backend of quantization + """ + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # We know that observed standalone module is a GraphModule since + # it's produced by us + observed_standalone_module: GraphModule = modules[str(node.target)] # type: ignore[assignment] + sm_input_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_input_quantized_idxs + # remove the dequantize nodes for inputs + args = list(node.args) + for idx in range(len(args)): + if idx in sm_input_quantized_idxs: + arg = args[idx] + if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr] + quantize_node = arg.args[0] # type: ignore[union-attr] + node.replace_input_with(arg, quantize_node) + if len(arg.users) == 0: # type: ignore[union-attr] + model.graph.erase_node(arg) + # add dequantize node for output + sm_output_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_output_quantized_idxs + if len(sm_output_quantized_idxs) > 0: + if sm_output_quantized_idxs[0] != 0: + raise AssertionError( + "Currently only quantized output idxs = [0] is supported" + ) + + # if it's non-empty, then it means the output is kept in quantized form + # we'll just add a dequantize node after this node + _insert_dequantize_node(node, model.graph) + + # TODO: allow convert_custom_config to override backend_config + # for standalone module + quantized_standalone_module = convert_fn( + observed_standalone_module, backend_config=backend_config + ) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(modules[parent_name], name, quantized_standalone_module) + modules[str(node.target)] = quantized_standalone_module + + +def convert_weighted_module( + node: Node, + modules: dict[str, torch.nn.Module], + observed_node_names: set[str], + node_name_to_qconfig: dict[str, QConfigAny], + backend_config: BackendConfig, + is_decomposed: bool = False, + is_reference: bool = False, + model_device: torch.device | None = None, +) -> None: + """Convert a weighted module to reference quantized module in the model + If the QConfig of a QAT module is not set, the module will still be converted to + a float module. + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - observed_node_names: names for the set of observed fx node, we can skip + this conversion if the node is not observed + """ + original_module = modules[str(node.target)] + qconfig: QConfigAny = original_module.qconfig # type: ignore[assignment] + weight_post_process = None + qat_module_classes = get_qat_module_classes(backend_config) + + if isinstance(original_module, qat_module_classes): + # Converting qat module to a float module, we need to attach + # weight fake_quant to the module, weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + weight_post_process = original_module.weight_fake_quant + original_module = original_module.to_float() # type: ignore[operator] + # change qat module to float module + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, original_module) + + is_observed = node.name in observed_node_names + # If a qconfig is not defined for this node, then skip converting to a reference module + if ( + qconfig is None + or _has_none_qconfig(node, node_name_to_qconfig) + or not is_observed + ): + return + + # skip converting to reference quantized module if the qconfig is not supported + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) + if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): + return + + # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized + is_weight_quantized = weight_is_quantized(qconfig) + + # the condition for swapping the module to reference quantized module is: + # weights need to be quantized + if not is_weight_quantized: + return + + fused_module = None + float_module = original_module + # extract the individual float_module and fused module + if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule): + fused_module = float_module + float_module = fused_module[0] # type: ignore[index] + + # TODO: move this to the reference quantized module + # weight_qparams or weight_qparams dict + wq_or_wq_dict = {"is_decomposed": is_decomposed} + if isinstance(float_module, torch.nn.RNNCellBase): + weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_ih(float_module.weight_ih) + weight_post_process_hh(float_module.weight_hh) + weight_qparams_ih = get_qparam_dict(weight_post_process_ih) + weight_qparams_hh = get_qparam_dict(weight_post_process_hh) + wq_or_wq_dict.update( + { + "weight_ih": weight_qparams_ih, + "weight_hh": weight_qparams_hh, + } + ) + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): + # format for wq_or_wq_dict (flattened attributes): + # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} + for wn in float_module._flat_weights_names: + if hasattr(float_module, wn) and wn.startswith("weight"): + weight = getattr(float_module, wn) + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if weight_post_process.dtype == torch.qint8: # type: ignore[union-attr] + weight_post_process(weight) # type: ignore[operator, misc] + wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) + else: + # weight_post_process is None means the original module is not a QAT module + # we need to get weight_post_process from qconfig in this case + is_ptq = weight_post_process is None + if is_ptq: + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if model_device is not None: + device = model_device + else: + device = assert_and_get_unique_device(float_module) + if device: + weight_post_process.to(device) + + # Call weight observer/fake_quant at least once to ensure the scales and zero points + # have the right shapes. Note: there are two cases where we don't have to do this: + # + # (1) QAT: The model's forward method already calls the weight observer/fake_quant, + # and this typically happens during training, so we don't need to do it here. + # + # (2) Non-reference (lowered) case: The quantized module's from_float method already + # calls the weight observer/fake_quant, so we don't have to do it here. + # + # Currently we ignore both cases and call the weight observer/fake_quant here + # regardless, which is technically incorrect. For (1), this is mainly to preserve BC + # in test code, which may not always train before convert. In the future, we should + # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941. + # + # For PT2, however, we don't need to preserve BC here, so we can skip this hack + # for QAT. We identify this case as (is_decomposed + is_reference + is_qat). + # Note that we still need it for PTQ in the PT2 flow since the model's forward + # method doesn't call the weight observer. + is_qat = not is_ptq + if not (is_decomposed and is_reference and is_qat): + weight_post_process(float_module.weight) # type: ignore[operator] + + wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) + + # We use the same reference module for all modes of quantization: static, dynamic, weight_only + # root_module_to_quantized_reference_module: module mapping from root (floating point) module class + # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + ref_qmodule_cls = root_module_to_quantized_reference_module.get( + type_before_parametrizations(float_module), None + ) + if ref_qmodule_cls is None: + raise AssertionError( + f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ) + ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] + if fused_module is not None: + fused_module[0] = ref_qmodule # type: ignore[operator] + else: + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, ref_qmodule) + + +def _remove_previous_dequantize_in_custom_module( + node: Node, prev_node: Node, graph: Graph +) -> None: + """ + Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: + + Before: quantize - dequantize - custom_module + After: quantize - custom_module + \\ - dequantize + """ + # expecting the input node for a custom module node to be a Node + if not isinstance(prev_node, Node): + raise AssertionError( + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + ) + if prev_node.op == "call_method" and prev_node.target == "dequantize": + node.replace_input_with(prev_node, prev_node.args[0]) + # Remove the dequantize node if it doesn't have other users + if len(prev_node.users) == 0: + graph.erase_node(prev_node) + + +def convert_custom_module( + node: Node, + graph: Graph, + modules: dict[str, torch.nn.Module], + custom_module_class_mapping: dict[QuantType, dict[type, type]], + statically_quantized_custom_module_nodes: set[Node], +) -> None: + """Converts an observed custom module to a quantized custom module based on + `custom_module_class_mapping` + For static quantization, we'll also remove the previous `dequantize` node and + attach the observer node for output to the module, the observer for the node + will be converted to a dequantize node instead of quantize-dequantize pairs + later in the graph. In the end we would have a quantized custom module that + has the same interface as a default quantized module in nn.quantized namespace, + i.e. quantized input and quantized output. + + Args: + - node: The call_module node of the observed standalone module + - graph: The graph containing the node + - modules: named_module of original model + - custom_module_class_mapping: mapping from observed custom module class to + quantized custom module class, used to swap custom modules + - statically_quantized_custom_module_nodes: we'll add the custom module node + if we find it is statically quantized, this will be used later when converting + observers to quant/dequant node pairs, if the observed node is a statically + quantized custom module nodes, we'll convert the observer to a dequantize node, + this is to keep the interface the same as the default quantized module. + TODO: maybe we want to redesign this part to align with reference model design + as well, but there has been some discussions around the interface, so we can do + it later. + """ + observed_custom_module = modules[str(node.target)] + qconfig = observed_custom_module.qconfig + if activation_is_statically_quantized(qconfig): + statically_quantized_custom_module_nodes.add(node) + if _is_custom_module_lstm(node, modules): + # The inputs are tuples in the form (input, (hidden0, hidden1)) + # Ensure all three input nodes are quantized + if not ( + len(node.args) == 2 + and isinstance(node.args[1], tuple) + and len(node.args[1]) == 2 + ): + raise AssertionError( + "Expected LSTM custom module inputs to be (input, (hidden0, hidden1))" + ) + (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] + if not isinstance(inputs, Node): + raise AssertionError("Expected inputs to be a Node") + if not isinstance(hidden0, Node): + raise AssertionError("Expected hidden0 to be a Node") + if not isinstance(hidden1, Node): + raise AssertionError("Expected hidden1 to be a Node") + _remove_previous_dequantize_in_custom_module(node, inputs, graph) + _remove_previous_dequantize_in_custom_module(node, hidden0, graph) + _remove_previous_dequantize_in_custom_module(node, hidden1, graph) + elif _is_custom_module_mha(node, modules): + # Inputs are in the form (query, key, value) + # TODO: This is the first step in enabling the full fx custom module + # quantization path for MultiheadAttention, and only covers the inputs + # to the module. + # Additional handling is yet to be implemented for the outputs, similar + # to LSTM custom module + if len(node.args) != 3: + raise AssertionError( + "Expected MHA custom module inputs to be (query, key, value)" + ) + query, key, value = node.args + if not isinstance(query, Node): + raise AssertionError("Expected query to be a Node") + if not isinstance(key, Node): + raise AssertionError("Expected key to be a Node") + if not isinstance(value, Node): + raise AssertionError("Expected value to be a Node") + _remove_previous_dequantize_in_custom_module(node, query, graph) + _remove_previous_dequantize_in_custom_module(node, key, graph) + _remove_previous_dequantize_in_custom_module(node, value, graph) + else: + # remove the previous dequant node to ensure the inputs are quantized + arg = node.args[0] + if not isinstance(arg, Node): + raise AssertionError("Expected arg to be a Node") + _remove_previous_dequantize_in_custom_module(node, arg, graph) + # absorb the following observer into the module conversion + activation_post_process = _maybe_get_observer_for_node(node, modules) + if activation_post_process is None: + raise AssertionError( + "Expected activation_post_process to be present for observed custom module" + ) + observed_custom_module.activation_post_process = activation_post_process + + # swap the observed custom module to quantized custom module + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig + ) + quantized_custom_module = quantized_custom_module_class.from_observed( + observed_custom_module + ) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, quantized_custom_module) + + +def convert( + model: GraphModule, + is_reference: bool = False, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, + _remove_qconfig_flag: bool = True, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """ + We will convert an observed model (a module with observer calls) to a reference + quantized model, the rule is simple: + 1. for each observer module call in the graph, we'll convert it to calls to + quantize and dequantize functions based on the observer instance + 2. for weighted operations like linear/conv, we need to convert them to reference + quantized module, this requires us to know whether the dtype configured for the + weight is supported in the backend, this is done in prepare step and the result + is stored in observed_node_names, we can decide whether we need to swap the + module based on this set + + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) + + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to convert is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = ( + QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None + ) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)): + raise AssertionError("qconfig_mapping must be None or a QConfigMapping") + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if backend_config is None: + backend_config = get_native_backend_config() + + if not _is_observed_module(model): + raise AssertionError("incoming model must be produced by prepare_fx") + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + node_name_to_scope: dict[str, tuple[str, type]] = ( + observed_graph_module_attrs.node_name_to_scope + ) + prepare_custom_config: PrepareCustomConfig = ( + observed_graph_module_attrs.prepare_custom_config + ) + observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names + node_name_to_qconfig: dict[str, QConfigAny] = ( + observed_graph_module_attrs.node_name_to_qconfig + ) # type: ignore[assignment] + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + # We use remove_duplicate=False here because torch.cat uses + # the same activation_post_process module instance but different names + modules = dict(model.named_modules(remove_duplicate=False)) + + # TODO refactor this code once we update the prepare logic to have additional information on + # which graph nodes have been observed and share that with convert to decide which observers to ignore. + if qconfig_mapping: + prepare_qconfig_mapping: QConfigMapping = ( + observed_graph_module_attrs.qconfig_mapping + ) # type: ignore[assignment] + modules_copy = copy.deepcopy(modules) + + if observed_graph_module_attrs.is_qat: + _update_qconfig_for_qat(qconfig_mapping, backend_config) + _update_qconfig_for_fusion(model, qconfig_mapping) + + _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping, qconfig_mapping + ) # type: ignore[arg-type] + convert_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope + ) + # check the convert_node_name_to_qconfig generated and ensure that + # all the values either match what was set in prepare node_name_to_qconfig + # or are set to None in the convert_node_name_to_qconfig. + for k, v in node_name_to_qconfig.items(): + if k not in convert_node_name_to_qconfig: + raise AssertionError( + f"Expected key {k} in convert node_name_to_qconfig" + ) + if convert_node_name_to_qconfig[k] is not None: + if not qconfig_equals(v, convert_node_name_to_qconfig[k]): + raise AssertionError( + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + ) + node_name_to_qconfig = convert_node_name_to_qconfig + + custom_module_classes = get_custom_module_class_keys( + convert_custom_config.observed_to_quantized_mapping + ) + custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping + + if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: + # If we want to do equalization then do the following: + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + _run_weight_observers(model, backend_config) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + # convert tuples so that it can work with isinstance(module, tuple_of_classes) + root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) + qat_module_classes = get_qat_module_classes(backend_config) + fused_module_classes = get_fused_module_classes(backend_config) + statically_quantized_custom_module_nodes: set[Node] = set() + model_device = assert_and_get_unique_device(model) + + for node in list(model.graph.nodes): + if node.op == "placeholder": + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + # Inputs are assumed to be quantized if the user specified the + # input_quantized_idxs override. + # we need to dequantize the inputs since all operators took + # floating point inputs in reference quantized models + _insert_dequantize_node(node, model.graph) + elif node.op == "output": + # If the argument is empty we don't need to do anything + if len(output_quantized_idxs) == 0: + continue + # Result are kept quantized if the user specified the + # output_quantized_idxs override. + # Remove the dequantize operator for the node in the end if any + return_node = node + output = node.args[0] + # outputs can be Node, list, tuple, dict, other cases are not supported yet + if isinstance(output, (list, tuple)): + for idx in output_quantized_idxs: + _maybe_recursive_remove_dequantize( + output[idx], return_node, model.graph + ) + elif isinstance(output, (Node, dict)): + # we treat dict as a single argument currently, but it can be extended + # to support {"key": dtype} after we change output_quantized_idxs to + # dict + if 0 in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output, return_node, model.graph) + else: + warnings.warn( + f"Unsupported node type for output_quantized_idxs: {type(output)}", + stacklevel=2, + ) + elif node.op == "call_module": + mod = _get_module(node, modules) + if mod is None: + raise AssertionError( + "Expected module for call_module node to be present in modules mapping" + ) + if _is_activation_post_process(mod): + observed_node = node.args[0] + if observed_node in statically_quantized_custom_module_nodes: + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + else: + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + model_device, + ) + else: + _replace_observer_with_quantize_dequantize_node( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + model_device, + ) + elif isinstance(mod, DeQuantStub): + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + elif _is_observed_standalone_module(mod): + convert_standalone_module( + node, modules, model, is_reference, backend_config + ) + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(mod) in set(root_module_classes).union( + qat_module_classes + ).union(fused_module_classes): + # extra check for fused module classes to make sure they are fused module classes + # of target modules + if ( + type_before_parametrizations(mod) in fused_module_classes + and type_before_parametrizations(mod[0]) not in root_module_classes + ): # type: ignore[index] + continue + convert_weighted_module( + node, + modules, + observed_node_names, + node_name_to_qconfig, + backend_config, + is_decomposed, + is_reference, + model_device, + ) + elif type_before_parametrizations(mod) in custom_module_classes: + convert_custom_module( + node, + model.graph, + modules, + custom_module_class_mapping, + statically_quantized_custom_module_nodes, + ) + + # remove deadcode after converting observers to quant/dequant ops + model.graph.eliminate_dead_code() + model = GraphModule(model, model.graph) + + # TODO: maybe move this to quantize_fx.py + if not is_reference: + model = lower_to_fbgemm( + model, node_name_to_qconfig, node_name_to_scope, keep_original_weights + ) + + # TODO: this looks hacky, we want to check why we need this and see if we can + # remove this + # removes qconfig and activation_post_process modules + if _remove_qconfig_flag: + _remove_qconfig(model) + model.delete_all_unused_submodules() + model.meta.pop("_observed_graph_module_attrs", None) + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e749de94bd5c3d1eb0c34a14cfcf38d441aedbff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py @@ -0,0 +1,521 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.quant_type import ( + _get_quant_type_to_str, + _quant_type_from_str, + QuantType, +) + + +__all__ = [ + "ConvertCustomConfig", + "FuseCustomConfig", + "PrepareCustomConfig", + "StandaloneModuleConfigEntry", +] + + +# TODO: replace all usages with these constants +STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name" +STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class" +FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class" +OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class" +NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name" +NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class" +INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs" +OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs" +PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes" + + +@dataclass +class StandaloneModuleConfigEntry: + # qconfig_mapping for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_mapping + qconfig_mapping: QConfigMapping | None + example_inputs: tuple[Any, ...] + prepare_custom_config: PrepareCustomConfig | None + backend_config: BackendConfig | None + + +class PrepareCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and + :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`. + + Example usage:: + + prepare_custom_config = PrepareCustomConfig() \ + .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \ + .set_non_traceable_module_names(["module2", "module3"]) \ + .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \ + .set_input_quantized_indexes([0]) \ + .set_output_quantized_indexes([0]) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.standalone_module_names: dict[str, StandaloneModuleConfigEntry] = {} + self.standalone_module_classes: dict[type, StandaloneModuleConfigEntry] = {} + self.float_to_observed_mapping: dict[QuantType, dict[type, type]] = {} + self.non_traceable_module_names: list[str] = [] + self.non_traceable_module_classes: list[type] = [] + self.input_quantized_indexes: list[int] = [] + self.output_quantized_indexes: list[int] = [] + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"PrepareCustomConfig({dict_nonempty})" + + def set_standalone_module_name( + self, + module_name: str, + qconfig_mapping: QConfigMapping | None, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | None, + backend_config: BackendConfig | None, + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_name``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_names[module_name] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_standalone_module_class( + self, + module_class: type, + qconfig_mapping: QConfigMapping | None, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | None, + backend_config: BackendConfig | None, + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_class``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_float_to_observed_mapping( + self, + float_class: type, + observed_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> PrepareCustomConfig: + """ + Set the mapping from a custom float module class to a custom observed module class. + + The observed module class must have a ``from_float`` class method that converts the float module class + to the observed module class. This is currently only supported for static quantization. + """ + if quant_type != QuantType.STATIC: + raise ValueError( + "set_float_to_observed_mapping is currently only supported for static quantization" + ) + if quant_type not in self.float_to_observed_mapping: + self.float_to_observed_mapping[quant_type] = {} + self.float_to_observed_mapping[quant_type][float_class] = observed_class + return self + + def set_non_traceable_module_names( + self, module_names: list[str] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by name. + """ + self.non_traceable_module_names = module_names + return self + + def set_non_traceable_module_classes( + self, module_classes: list[type] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by class. + """ + self.non_traceable_module_classes = module_classes + return self + + def set_input_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the inputs of the graph that should be quantized. + Inputs are otherwise assumed to be in fp32 by default instead. + """ + self.input_quantized_indexes = indexes + return self + + def set_output_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the outputs of the graph that should be quantized. + Outputs are otherwise assumed to be in fp32 by default instead. + """ + self.output_quantized_indexes = indexes + return self + + def set_preserved_attributes(self, attributes: list[str]) -> PrepareCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, prepare_custom_config_dict: dict[str, Any] + ) -> PrepareCustomConfig: + """ + Create a ``PrepareCustomConfig`` from a dictionary with the following items: + + "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "float_to_observed_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from float module classes to observed module classes, e.g. + {"static": {FloatCustomModule: ObservedCustomModule}} + + "non_traceable_module_name": a list of modules names that are not symbolically traceable + "non_traceable_module_class": a list of module classes that are not symbolically traceable + "input_quantized_idxs": a list of indexes of graph inputs that should be quantized + "output_quantized_idxs": a list of indexes of graph outputs that should be quantized + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + + def _get_qconfig_mapping(obj: Any, dict_key: str) -> QConfigMapping | None: + """ + Convert the given object into a QConfigMapping if possible, else throw an exception. + """ + if isinstance(obj, QConfigMapping) or obj is None: + return obj + if isinstance(obj, dict): + return QConfigMapping.from_dict(obj) + raise ValueError( + f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_prepare_custom_config( + obj: Any, dict_key: str + ) -> PrepareCustomConfig | None: + """ + Convert the given object into a PrepareCustomConfig if possible, else throw an exception. + """ + if isinstance(obj, PrepareCustomConfig) or obj is None: + return obj + if isinstance(obj, dict): + return PrepareCustomConfig.from_dict(obj) + raise ValueError( + f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_backend_config(obj: Any, dict_key: str) -> BackendConfig | None: + """ + Convert the given object into a BackendConfig if possible, else throw an exception. + """ + if isinstance(obj, BackendConfig) or obj is None: + return obj + if isinstance(obj, dict): + return BackendConfig.from_dict(obj) + raise ValueError( + f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + conf = cls() + for ( + module_name, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + conf.set_standalone_module_name( + module_name, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for ( + module_class, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + conf.set_standalone_module_class( + module_class, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get( + FLOAT_TO_OBSERVED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for float_class, observed_class in custom_module_mapping.items(): + conf.set_float_to_observed_mapping( + float_class, observed_class, quant_type + ) + conf.set_non_traceable_module_names( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, []) + ) + conf.set_non_traceable_module_classes( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, []) + ) + conf.set_input_quantized_indexes( + prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_output_quantized_indexes( + prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_preserved_attributes( + prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``PrepareCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. + """ + + def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): + qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None + prepare_custom_config_dict = ( + e.prepare_custom_config.to_dict() if e.prepare_custom_config else None + ) + return ( + key, + qconfig_dict, + e.example_inputs, + prepare_custom_config_dict, + e.backend_config, + ) + + d: dict[str, Any] = {} + for module_name, sm_config_entry in self.standalone_module_names.items(): + if STANDALONE_MODULE_NAME_DICT_KEY not in d: + d[STANDALONE_MODULE_NAME_DICT_KEY] = [] + d[STANDALONE_MODULE_NAME_DICT_KEY].append( + _make_tuple(module_name, sm_config_entry) + ) + for module_class, sm_config_entry in self.standalone_module_classes.items(): + if STANDALONE_MODULE_CLASS_DICT_KEY not in d: + d[STANDALONE_MODULE_CLASS_DICT_KEY] = [] + d[STANDALONE_MODULE_CLASS_DICT_KEY].append( + _make_tuple(module_class, sm_config_entry) + ) + for ( + quant_type, + float_to_observed_mapping, + ) in self.float_to_observed_mapping.items(): + if FLOAT_TO_OBSERVED_DICT_KEY not in d: + d[FLOAT_TO_OBSERVED_DICT_KEY] = {} + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + float_to_observed_mapping + ) + if len(self.non_traceable_module_names) > 0: + d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names + if len(self.non_traceable_module_classes) > 0: + d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes + if len(self.input_quantized_indexes) > 0: + d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes + if len(self.output_quantized_indexes) > 0: + d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class ConvertCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. + + Example usage:: + + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.observed_to_quantized_mapping: dict[QuantType, dict[type, type]] = {} + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"ConvertCustomConfig({dict_nonempty})" + + def set_observed_to_quantized_mapping( + self, + observed_class: type, + quantized_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> ConvertCustomConfig: + """ + Set the mapping from a custom observed module class to a custom quantized module class. + + The quantized module class must have a ``from_observed`` class method that converts the observed module class + to the quantized module class. + """ + if quant_type not in self.observed_to_quantized_mapping: + self.observed_to_quantized_mapping[quant_type] = {} + self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class + return self + + def set_preserved_attributes(self, attributes: list[str]) -> ConvertCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, convert_custom_config_dict: dict[str, Any] + ) -> ConvertCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from observed module classes to quantized module classes, e.g.:: + { + "static": {FloatCustomModule: ObservedCustomModule}, + "dynamic": {FloatCustomModule: ObservedCustomModule}, + "weight_only": {FloatCustomModule: ObservedCustomModule} + } + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + for quant_type_name, custom_module_mapping in convert_custom_config_dict.get( + OBSERVED_TO_QUANTIZED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for observed_class, quantized_class in custom_module_mapping.items(): + conf.set_observed_to_quantized_mapping( + observed_class, quantized_class, quant_type + ) + conf.set_preserved_attributes( + convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``ConvertCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. + """ + d: dict[str, Any] = {} + for ( + quant_type, + observed_to_quantized_mapping, + ) in self.observed_to_quantized_mapping.items(): + if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: + d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + observed_to_quantized_mapping + ) + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class FuseCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`. + + Example usage:: + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + ["attr1", "attr2"] + ) + """ + + def __init__(self) -> None: + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"FuseCustomConfig({dict_nonempty})" + + def set_preserved_attributes(self, attributes: list[str]) -> FuseCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, fuse_custom_config_dict: dict[str, Any]) -> FuseCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + conf.set_preserved_attributes( + fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``FuseCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. + """ + d: dict[str, Any] = {} + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4ee15779a180ea88c7dda47c7e6a45da092714 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py @@ -0,0 +1,195 @@ +# mypy: allow-untyped-defs +import warnings +from collections.abc import Callable +from typing import Any + +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fuser_method_mapping, + get_fusion_pattern_to_extra_inputs_getter, + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.utils import NodePattern, Pattern +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .custom_config import FuseCustomConfig +from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler +from .match_utils import _is_match, MatchAllNode +from .pattern_utils import _sorted_patterns_dict + + +__all__ = [ + "fuse", + # TODO: We should make this private in the future + # This is currently needed for test_public_bindings for some reason + "FuseHandler", +] + + +def fuse( + model: GraphModule, + is_qat: bool, + fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + named_modules = dict(model.named_modules()) + + if backend_config is None: + backend_config = get_native_backend_config() + + fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict( + _get_fusion_pattern_to_fuse_handler_cls(backend_config) + ) + fuser_method_mapping = get_fuser_method_mapping(backend_config) + fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter( + backend_config + ) + fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter( + backend_config + ) + + # find fusion + fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls) + # TODO: change this to inplace changes to graph, since we no longer construct + # new GraphModule anymore + fused_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + def default_root_node_getter(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + + for node in model.graph.nodes: + ( + maybe_last_node, + pattern, + matched_node_pattern, + obj, + node_to_subpattern, + ) = fusion_pairs.get(node.name, (None, None, None, None, None)) + # get the corresponding subpattern for the current node + if node_to_subpattern is not None: + node_subpattern = node_to_subpattern.get(node, None) + else: + node_subpattern = None + if maybe_last_node is node: + if obj is None: + raise AssertionError( + "fuse handler object must not be None for matched root node" + ) + root_node_getter = fusion_pattern_to_root_node_getter.get( + pattern, default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) # type: ignore[index] + extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get( + pattern, None + ) + extra_inputs = [] + if extra_inputs_getter is not None: + extra_inputs = extra_inputs_getter(matched_node_pattern) + # TODO: add validation that root_node is a module and has the same type + # as the root_module in the configuration + env[node.name] = obj.fuse( + load_arg, + named_modules, + fused_graph, + root_node, + extra_inputs, + matched_node_pattern, # type: ignore[arg-type] + fuse_custom_config, + fuser_method_mapping, + is_qat, + ) + elif maybe_last_node is None or node_subpattern is MatchAllNode: + env[node.name] = fused_graph.node_copy(node, load_arg) + # node matched in patterns and is not root is removed here + + model = GraphModule(model, fused_graph) + return model + + +def _find_matches( + root: GraphModule, + graph: Graph, + pattern_to_fuse_handler_cls: dict[Pattern, Callable], +) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]: + modules = dict(root.named_modules()) + # node name -> (root_node, match_value) + match_map: dict[ + str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]] + ] = {} + # a map from node to the matched subpattern + node_to_subpattern: dict[Node, Any] = {} + + # TODO: dedup with quantization matching function in match_utils.py + def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern): + if isinstance(pattern, tuple): + s, *args = pattern + current_node_pattern: list[Node] = [] + apply_match(s, node, match, current_node_pattern, node_to_subpattern) + for subpattern, arg in zip(args, node.args): + apply_match( + subpattern, arg, match, current_node_pattern, node_to_subpattern + ) + matched_node_pattern.append(tuple(current_node_pattern)) + else: + # the first pattern matches will take precedence + if node.name not in match_map: + matched_node_pattern.append(node) + # MatchAllNode here is actually MatchAllInputNode which should not + # be added to match_map + if pattern is not MatchAllNode: + node_to_subpattern[node] = pattern + root_node, pattern, handler = match + match_map[node.name] = ( + root_node, + pattern, + matched_node_pattern, + handler, + node_to_subpattern, + ) + + for node in reversed(graph.nodes): + if node.name not in match_map: + for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items(): + matched_node_pattern: list[Node] = [] + if _is_match(modules, node, pattern): + apply_match( + pattern, + node, + (node, pattern, fuse_handler_cls(node)), + matched_node_pattern, + node_to_subpattern, + ) + break + + return match_map diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..76fe84c2c3ad5fd88303d5f04e83c3dccfd24e5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new +from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .custom_config import FuseCustomConfig +from .match_utils import MatchAllNode + + +__all__ = [ + "DefaultFuseHandler", + "FuseHandler", +] + + +# ---------------------------- +# Fusion Pattern Registrations +# ---------------------------- + + +# Base Pattern Handler +class FuseHandler(ABC): + """Base handler class for the fusion patterns""" + + @abstractmethod + def __init__(self, node: Node): + pass + + @abstractmethod + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, torch.nn.Sequential | Callable], + is_qat: bool, + ) -> Node: + pass + + +class DefaultFuseHandler(FuseHandler): + def __init__(self, node: Node): # pylint: disable=useless-parent-delegation + super().__init__(node) # type:ignore[safe-super] + + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, torch.nn.Sequential | Callable], + is_qat: bool, + ) -> Node: + if root_node.op != "call_module": + raise AssertionError("Expecting module node to be a call_module Node") + root_module = named_modules[str(root_node.target)] + + def get_modules(pattern): + """Given a node pattern, extract the corresponding modules + e.g. input: (relu_node, (bn_node, conv_node)) + output: (relu_module, (bn_module, conv_module)) + """ + if isinstance(pattern, (tuple, list)): + n, *args = pattern + modules: list[torch.nn.Module] = [] + modules.append(get_modules(n)) + modules.extend(get_modules(a) for a in args) + return tuple(modules) + else: + n = pattern + if n.op == "call_module": + return named_modules[n.target] + elif n.op == "call_function" and n.target is torch.nn.functional.relu: + relu = torch.nn.ReLU() + relu.training = root_module.training + return relu + elif n.op == "call_function" or n.op == "call_method": + return n.target + else: + return MatchAllNode + + # since relu can be used multiple times, we'll need to create a relu module for each match + matched_modules = get_modules(matched_node_pattern) + + def get_matched_types(m): + if isinstance(m, tuple): + return tuple(map(get_matched_types, m)) + if isinstance(m, torch.nn.Module): + return type_before_parametrizations(m) + return m + + matched_module_types = get_matched_types(matched_modules) + module_parent_name, module_name = _parent_name(root_node.target) + fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping) + # TODO: change the signature for fuser_method to take matched module patterns + # as input + fused_module = fuser_method(is_qat, *matched_modules) + setattr(named_modules[module_parent_name], module_name, fused_module) + extra_args = [load_arg(input) for input in extra_inputs] + node = fused_graph.node_copy(root_node, load_arg) + args = list(node.args) + args.extend(extra_args) + node.args = tuple(args) + return node + + +def _get_fusion_pattern_to_fuse_handler_cls( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + fusion_pattern_to_fuse_handlers: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # TODO: is this logic right? + fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler + return fusion_pattern_to_fuse_handlers diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..87ec3179a68ee26a5b2199c3f7543fdfd73e2864 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py @@ -0,0 +1,205 @@ +# mypy: allow-untyped-defs +import copy +from typing import Any + +import torch +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__ = [ + "FusedGraphModule", + "ObservedGraphModule", + "ObservedStandaloneGraphModule", + "QuantizedGraphModule", +] + + +class FusedGraphModule(GraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return FusedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +class ObservedGraphModule(GraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = { + "_activation_post_process_map", + "_activation_post_process_indexes", + "_patterns", + "_node_name_to_qconfig", + "_prepare_custom_config", + "_equalization_node_name_to_qconfig", + "_node_name_to_scope", + "_qconfig_mapping", + "_is_qat", + "_observed_node_names", + }.union(preserved_attr_names) + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_module(module: Any) -> bool: + return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta + + +def _get_observed_graph_module_attr( + model: torch.nn.Module | GraphModule, attr_name: str +) -> Any: + if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta: # type: ignore[operator, index] + return getattr(model.meta["_observed_graph_module_attrs"], attr_name) # type: ignore[index] + return None + + +class ObservedStandaloneGraphModule(ObservedGraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + preserved_attr_names = preserved_attr_names.union( + { + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs", + } + ) + super().__init__(root, graph, preserved_attr_names) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_standalone_module(module: Any) -> bool: + return ( + _is_observed_module(module) + and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module + ) + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +class QuantizedGraphModule(GraphModule): + """This class is created to make sure PackedParams + (e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict + so that we can serialize and deserialize quantized graph module with + torch.save(m.state_dict()) and m.load_state_dict(state_dict) + """ + + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + self._register_state_dict_hook(_save_packed_weight) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return QuantizedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..73fd3e8741b6d6c26d5a352d25d4cf6986de4d9d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py @@ -0,0 +1,21 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_fbgemm"] + + +def lower_to_fbgemm( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to fbgemm + """ + return _lower_to_native_backend( + model, qconfig_map, node_name_to_scope, keep_original_weights + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fa3ecf3f5a3b2b5dc67d769853f8424bae7efb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py @@ -0,0 +1,18 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_qnnpack"] + + +def lower_to_qnnpack( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to qnnpack + """ + return _lower_to_native_backend(model, qconfig_map, node_name_to_scope) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78849692a45efab6b7ce3af208ee16f6d77286c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py @@ -0,0 +1,228 @@ +import copy +import operator +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization import ( + default_weight_fake_quant, + default_weight_observer, + FakeQuantizeBase, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.observer import _PartialWrapper +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx + + +if TYPE_CHECKING: + from collections.abc import Callable + + +# TODO: move all LSTM util functions from fx/utils.py to this file +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + example_inputs: tuple[Any, ...], + backend_config: BackendConfig | None = None, + linear_output_obs_ctr: _PartialWrapper | None = None, + sigmoid_obs_ctr: _PartialWrapper | None = None, + tanh_obs_ctr: _PartialWrapper | None = None, + cell_state_obs_ctr: _PartialWrapper | None = None, + hidden_state_obs_ctr: _PartialWrapper | None = None, + split_gates: bool = False, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + This is meant to be used to convert a float module to an observed module in the + custom module flow. + + Args: + `float_lstm`: The float LSTM module + `example_inputs`: example inputs for the forward function of the LSTM module + `backend_config`: BackendConfig to use to observe the LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + assigned to the inner ops. + """ + + def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. + """ + if isinstance(obs_ctr(), FakeQuantizeBase): + weight = default_weight_fake_quant + else: + weight = default_weight_observer + return QConfig(activation=obs_ctr, weight=weight) + + quantizable_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, + float_lstm.hidden_size, + float_lstm.num_layers, + float_lstm.bias, + float_lstm.batch_first, + float_lstm.dropout, + float_lstm.bidirectional, + split_gates=split_gates, + ) + quantizable_lstm.qconfig = float_lstm.qconfig + + for idx in range(float_lstm.num_layers): + quantizable_lstm.layers[idx] = ( + torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( + float_lstm, + idx, + float_lstm.qconfig, + batch_first=False, + split_gates=split_gates, + ) + ) + + # Build QConfigMapping for the LSTM cell + # Note: FloatFunctional qconfigs will be configured separately below + cell_qm = QConfigMapping().set_global(float_lstm.qconfig) # type: ignore[arg-type] + if sigmoid_obs_ctr is not None: + cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr)) + if tanh_obs_ctr is not None: + cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr)) + + # Insert observers into each LSTM cell + # TODO: maybe make this work for layer_bw as well + for layer in quantizable_lstm.layers: + cell = layer.layer_fw.cell # type: ignore[union-attr] + if not isinstance(cell, torch.nn.Module): + raise AssertionError("cell should be a nn.Module") + cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) + # HACK: Manually replace the activation_post_process following these ops. + # This is needed for FloatFunctional ops because there is currently no way + # to configure these ops in FX graph mode quantization today. This is because + # the FloatFunctional modules simply disappear from the graph after tracing. + # In the future, we should rewrite quantizable LSTM without FloatFunctionals. + if not split_gates: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + else: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add (input) + (torch.add, 1): linear_output_obs_ctr, # gates.add (forget) + (torch.add, 2): linear_output_obs_ctr, # gates.add (cell) + (torch.add, 3): linear_output_obs_ctr, # gates.add (output) + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + add_count = 0 + mul_count = 0 + for node in cell.graph.nodes: + op_index: tuple[Callable, int] | None = None # e.g. (torch.add, 1) + if node.target is torch.add: + op_index = (torch.add, add_count) + add_count += 1 + elif node.target is torch.mul: + op_index = (torch.mul, mul_count) + mul_count += 1 + else: + # Neither torch.add nor torch.mul + continue + if op_index not in op_index_to_activation_post_process_ctr: + continue + if len(node.users) != 1: + raise AssertionError("expected exactly one user for the node") + activation_post_process_name = next(iter(node.users.keys())).name + activation_post_process_ctr = op_index_to_activation_post_process_ctr[ + op_index + ] + if activation_post_process_ctr is not None: + setattr( + cell, activation_post_process_name, activation_post_process_ctr() + ) + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantizable_lstm + + +def _get_reference_quantized_lstm_module( + observed_lstm: torch.ao.nn.quantizable.LSTM, + backend_config: BackendConfig | None = None, +) -> torch.ao.nn.quantized.LSTM: + """ + Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM` + with observers or fake quantizes inserted through `prepare_fx`, e.g. from + `_get_lstm_with_individually_observed_parts`. + + This is meant to be used to convert an observed module to a quantized module in the + custom module flow. + + Args: + `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx` + `backend_config`: BackendConfig to use to produce the reference quantized model + + Return: + A reference `torch.ao.nn.quantized.LSTM` module. + """ + quantized_lstm = torch.ao.nn.quantized.LSTM( + observed_lstm.input_size, + observed_lstm.hidden_size, + observed_lstm.num_layers, + observed_lstm.bias, + observed_lstm.batch_first, + observed_lstm.dropout, + observed_lstm.bidirectional, + ) + + for i, layer in enumerate(quantized_lstm.layers): + cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] + cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] + if not isinstance(cell, torch.fx.GraphModule): + raise AssertionError("cell must be converted to a torch.fx.GraphModule") + # HACK: Manually remove input quantize nodes and output dequantize nodes, + # since custom modules expect quint8 inputs and outputs for now. Note that + # this functionality is supposedly handled through PrepareCustomConfig's + # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that + # API doesn't currently handle tuple inputs and outputs, so we have to do + # this manually for now. In the future we should (1) relax the restriction + # on custom module input/output dtypes, and (2) expand support for complex + # input/output structures. + for node in cell.graph.nodes: + if node.target is torch.quantize_per_tensor: + arg = node.args[0] + # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1]) + if arg.target == "x" or ( + arg.target is operator.getitem and arg.args[0].target == "hidden" + ): + with cell.graph.inserting_before(node): + node.replace_all_uses_with(arg) + cell.graph.erase_node(node) + if node.target == "output": + # Remove all dequantize nodes in the output tuple + for arg in node.args[0]: + with cell.graph.inserting_before(node): + node.replace_input_with(arg, arg.args[0]) + cell.graph.eliminate_dead_code() + cell.recompile() + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantized_lstm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79194caa4a17b9f2db99981b184081d09df80e84 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +import sys +from collections.abc import Callable, Iterable +from typing import Any + +import torch +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import MatchAllNode, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .graph_module import _is_observed_standalone_module +from .quantize_handler import QuantizeHandler + + +__all__: list[str] = [] + +# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type +# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]` +_MatchResult = tuple[Node, list[Node], Pattern | None, QuantizeHandler] + +_MatchResultWithQConfig = tuple[ + Node, list[Node], Pattern | None, QuantizeHandler, QConfigAny +] + + +# Note: The order of patterns is important! match function will take whatever is matched first, so we'll +# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu. +# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns, +# we'll start from the last node of the graph and traverse back. +def _is_match(modules, node, pattern, max_uses=sys.maxsize): + """Matches a node in fx against a pattern""" + if isinstance(pattern, tuple): + self_match, *arg_matches = pattern + if self_match is getattr: + if len(pattern) != 2: + raise AssertionError("Expecting getattr pattern to have two elements") + arg_matches = [] + else: + self_match = pattern + arg_matches = [] + + if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): + return True + + if node == pattern: + return True + + if not isinstance(node, Node) or len(node.users) > max_uses: + return False + + if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): + if node.op != "call_module": + return False + if type_before_parametrizations(modules[node.target]) != self_match: + return False + elif callable(self_match): + if node.op != "call_function" or node.target is not self_match: + return False + elif node.target is getattr: + if node.args[1] != pattern[1]: + return False + elif isinstance(self_match, str): + if node.op != "call_method" or node.target != self_match: + return False + elif node.target != self_match: + return False + + if not arg_matches: + return True + + if len(arg_matches) != len(node.args): + return False + + return all( + _is_match(modules, node, arg_match, max_uses=1) + for node, arg_match in zip(node.args, arg_matches) + ) + + +def _find_matches( + graph: Graph, + modules: dict[str, torch.nn.Module], + patterns: dict[Pattern, QuantizeHandler], + root_node_getter_mapping: dict[Pattern, Callable], + standalone_module_names: list[str] | None = None, + standalone_module_classes: list[type] | None = None, + custom_module_classes: list[Any] | None = None, +) -> dict[str, _MatchResult]: + """ + Matches the nodes in the input graph to quantization patterns, and + outputs the information needed to quantize them in future steps. + + Inputs: + - graph: an fx.Graph object + - modules: a mapping of fully qualified module name to instance, + for example, {'foo': ModuleFoo, ...} + - patterns: a mapping from a tuple of nodes in reverse order to + uninitialized QuantizeHandler subclass. + + Outputs a map of + node_name -> + (node, matched_values, matched_pattern, QuantizeHandler instance, + qconfig) + + For example, { + 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, + , QConfig(...)), + ... + } + """ + if custom_module_classes is None: + custom_module_classes = [] + + if standalone_module_classes is None: + standalone_module_classes = [] + + if standalone_module_names is None: + standalone_module_names = [] + + match_map: dict[str, _MatchResult] = {} + all_matched: set[str] = set() + + def _recursive_record_node_in_match_map( + last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value + ): + if isinstance(node_pattern, Node): + match_map[node_pattern.name] = ( + last_node, + matched_node_pattern, + pattern, + match_value, + ) + elif not isinstance(node_pattern, Iterable): + return + else: + for n in node_pattern: + _recursive_record_node_in_match_map( + last_node, match_map, n, matched_node_pattern, pattern, match_value + ) + + # TODO: 1. merge with fuse matcher 2. document the code + def record_match(pattern, node, last_node, matched_node_pattern, match_map): + if isinstance(pattern, tuple): + s, *args = pattern + is_single_arg = len(args) == 1 + current_node_pattern: list[Node] = [] + record_match(s, node, last_node, matched_node_pattern, match_map) + if pattern[0] is not getattr: + for subpattern, arg in zip(args, node.args): + record_match(subpattern, arg, node, current_node_pattern, match_map) + if len(current_node_pattern) > 1: + # current_node_pattern is the node pattern we get from matching + # the subpattern with arguments of the node + # we use is_single_arg to recover the original structure of the pattern + # if the original pattern has a single argument, we will have + # (original_op, (original_arg, ...)) + # otherwise, we'll have a list of arguments + # (original_op, arg0, arg1, arg2, ...) + if is_single_arg: + matched_node_pattern.append(tuple(current_node_pattern)) + else: + matched_node_pattern.extend(list(current_node_pattern)) + else: + matched_node_pattern.append(current_node_pattern[0]) + else: + matched_node_pattern.append(node) + + for node in reversed(graph.nodes): + if node.name not in match_map and node.name not in all_matched: + for pattern, quantize_handler_cls in patterns.items(): + root_node_getter = root_node_getter_mapping.get(pattern) + if _is_match(modules, node, pattern) and node.name not in match_map: + matched_node_pattern: list[Node] = [] + record_match(pattern, node, node, matched_node_pattern, match_map) + quantize_handler = quantize_handler_cls( # type: ignore[operator] + matched_node_pattern, modules, root_node_getter + ) + last_node = node + # record the match for all nodes in the pattern + _recursive_record_node_in_match_map( + last_node, + match_map, + # we need to record all nodes in the matched pattern in the match_map + matched_node_pattern, + # this is a part of the value corresponding to the node + matched_node_pattern, + pattern, + quantize_handler, + ) + break + + # add custom module instances to the match result + if modules is None: + raise AssertionError("modules must not be None") + for node in graph.nodes: + if ( + node.op == "call_module" + and type(modules[node.target]) in custom_module_classes + ): + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_custom_module=True), + ) + + def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]): + if modules is None: + raise AssertionError("modules must not be None") + return ( + node_target in standalone_module_names + or type(modules[node_target]) # type: ignore[operator] + in standalone_module_classes # type: ignore[operator] + ) + + # add standalone modules to the match + for node in graph.nodes: + if node.op == "call_module" and ( + is_standalone_module(node.target, modules) + or _is_observed_standalone_module(modules[node.target]) + ): + # add node to matched nodes + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_standalone_module=True), + ) + + return match_map diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e86f95d67aba092daff6a3a14a14767f29d249a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import copy +from collections import OrderedDict +from typing import Any + +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "get_default_fusion_patterns", + "get_default_quant_patterns", + "get_default_output_activation_post_process_map", +] + +# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency) +QuantizeHandler = Any + +# pattern for conv bn fusion +_DEFAULT_FUSION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + + +def _register_fusion_pattern(pattern): + def insert(fn): + _DEFAULT_FUSION_PATTERNS[pattern] = fn + return fn + + return insert + + +def get_default_fusion_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_FUSION_PATTERNS) + + +_DEFAULT_QUANTIZATION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + +# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant +_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: dict[Pattern, QuantizeHandler] = {} +_DEFAULT_OUTPUT_OBSERVER_MAP: dict[Pattern, QuantizeHandler] = {} + + +# Register pattern for both static quantization and qat +def _register_quant_pattern(pattern, fixed_qparams_observer=None): + def insert(fn): + _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if fixed_qparams_observer is not None: + _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = ( + FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + ) + _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer + return fn + + return insert + + +# Get patterns for both static quantization and qat +def get_default_quant_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS) + + +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map( + is_training, +) -> dict[Pattern, ObserverBase]: + if is_training: + return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP) + else: + return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP) + + +# Example use of register pattern function: +# @_register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +# class ConvOrLinearBNReLUFusion(): +# def __init__(...): +# ... +# + + +def _sorted_patterns_dict( + patterns_dict: dict[Pattern, QuantizeHandler], +) -> dict[Pattern, QuantizeHandler]: + """ + Return a sorted version of the patterns dictionary such that longer patterns are matched first, + e.g. match (F.relu, F.linear) before F.relu. + This works for current use cases, but we may need to have a more clever way to sort + things to address more complex patterns + """ + + def get_len(pattern): + """this will calculate the length of the pattern by counting all the entries + in the pattern. + this will make sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before + (nn.BatchNorm, nn.Conv2d) so that we can match the former first + """ + len = 0 + if isinstance(pattern, tuple): + for item in pattern: + len += get_len(item) + else: + len += 1 + return len + + return OrderedDict( + sorted( + patterns_dict.items(), + key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1, + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2fab3f27eb917b22368dae04cd908f2f81a7c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py @@ -0,0 +1,2251 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from dataclasses import asdict +from typing import Any + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + _DerivedObserverOrFakeQuantize, + FixedQParamsFakeQuantize, + FixedQParamsObserver, + ObserverBase, + ObserverOrFakeQuantize, + PlaceholderObserver, +) +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, + get_module_to_qat_module, + get_pattern_to_dtype_configs, +) +from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper +from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize import convert, propagate_qconfig_ +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.ao.quantization.utils import ( + _parent_name, + get_qconfig_dtypes, + get_swapped_custom_module_class, + NodePattern, + Pattern, +) +from torch.fx import GraphModule +from torch.fx.graph import Graph, Node +from torch.fx.node import Argument + +from ._equalize import is_equalization_observer, node_supports_equalization +from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry +from .match_utils import _find_matches, _MatchResultWithQConfig +from .pattern_utils import _sorted_patterns_dict +from .qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, + _get_flattened_qconfig_dict, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .quantize_handler import ( + _default_root_node_getter, + _get_pattern_to_quantize_handlers, + QuantizeHandler, +) +from .utils import ( + _insert_dequant_stubs_for_custom_module_lstm_output, + _is_custom_module_lstm, + _maybe_get_custom_module_lstm_from_node_arg, + _qconfig_satisfies_dtype_config_constraints, + all_node_args_have_no_tensors, + assert_and_get_unique_device, + get_custom_module_class_keys, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + node_arg_is_bias, + node_arg_is_weight, + NON_QUANTIZABLE_WEIGHT_OPS, + ObservedGraphModuleAttrs, +) + + +__all__ = [ + "insert_observers_for_model", + "prepare", + "propagate_dtypes_for_known_nodes", +] + + +# list of dtypes to not add observers to +_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None] +_OBS_DTYPE_LIST = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float) + +# note: the following default target dtype info dicts are temporary, +# should be moved to the new programmable API class soon +_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, +} + +_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, +} + + +def _get_observer_kwargs( + quant_spec: QuantizationSpec | FixedQParamsQuantizationSpec, +): + kwargs_dict = asdict(quant_spec) + return copy.deepcopy(kwargs_dict) + + +def _get_qspec_for_arg( + arg: Node, + input_qspec_map: dict[Node, QuantizationSpecBase], + named_modules: dict[str, torch.nn.Module], +) -> QuantizationSpecBase | None: + while _is_activation_post_process_node(arg, named_modules): + arg = arg.args[0] # type: ignore[assignment] + return input_qspec_map.get(arg) + + +def _create_obs_or_fq_from_qspec( + quantization_spec: QuantizationSpecBase | None, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + """Create observer or fake quantize objects based on quantization spec + + Args: + quantization_spec: used to store parameters to create the observer or fake quantizer + obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant + instance, it may be reused for different edge/output depending on configuration + """ + if quantization_spec is None: + return None + if isinstance(quantization_spec, SharedQuantizationSpec): + edge_or_node = quantization_spec.edge_or_node + if edge_or_node not in obs_or_fq_map: + raise AssertionError( + "please make sure only refer to edge or node that has " + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + ) + return obs_or_fq_map[edge_or_node] + elif isinstance(quantization_spec, DerivedQuantizationSpec): + # can't use asdict, so not calling get_observer_kwargs here + kwargs = { + "dtype": quantization_spec.dtype, + "derive_qparams_fn": quantization_spec.derive_qparams_fn, + "quant_min": quantization_spec.quant_min, + "quant_max": quantization_spec.quant_max, + "qscheme": quantization_spec.qscheme, + "ch_axis": quantization_spec.ch_axis, + } + edge_or_nodes = quantization_spec.derived_from + obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + # pyrefly: ignore [unsupported-operation] + kwargs["obs_or_fqs"] = obs_or_fqs + return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() + elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): + kwargs = _get_observer_kwargs(quantization_spec) + observer_ctr = FixedQParamsObserver.with_args(**kwargs) + if is_qat: + return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)() + else: + return observer_ctr() + + if not isinstance(quantization_spec, QuantizationSpec): + raise AssertionError("quantization_spec must be a QuantizationSpec") + observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr + kwargs = _get_observer_kwargs(quantization_spec) + kwargs.pop("observer_or_fake_quant_ctr") + # we will remove is_dynamic from QuantizationSpec because + # it seems that dynamic range quantization + obs_or_fq_class = observer_or_fake_quant_ctr + if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): + obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] + if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] + kwargs.pop("ch_axis") + return observer_or_fake_quant_ctr.with_args(**kwargs)() + + +def _needs_obs_or_fq( + prev_output_dtype: Any, + prev_output_is_dynamic: bool, + cur_target_dtype: Any, + cur_target_is_dynamic: bool, + reuse_input_obs_or_fq: bool, + is_zeroth_arg: bool = False, +) -> bool: + """ + note: we will treat "not specified" as torch.float for now + utility function that checks if we should insert an observer or fake quant node + base on the requested dtype for the nodes from user + + is_zeroth_arg: we only dynamically quantize the first arg of the node right now + this should be removed when we enable configuring dynamic quantization + for a specific argument, this can be removed if we deprecate fx graph mode + quantization + + """ + + # need to insert placeholder observer for dynamic quantization so that it can + # be converted to choose_qparams -> q -> dq in convert step + if cur_target_is_dynamic: + if cur_target_dtype not in _OBS_DTYPE_LIST: + raise AssertionError( + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + ) + if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST: + raise AssertionError( + "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST" + ) + return is_zeroth_arg + if reuse_input_obs_or_fq: + return False + # non dynamic quantization + if cur_target_dtype in _OBS_DTYPE_LIST: + return ( + prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] + and cur_target_dtype != prev_output_dtype + ) + + # lots of error checking are skipped here for now + return False + + +def _is_activation_post_process_node( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == "call_module" + and _is_activation_post_process(named_modules[str(node.target)]) + ) + + +def _get_dtype_and_is_dynamic( + obs_or_fq: ObserverOrFakeQuantize | None, +) -> tuple[torch.dtype | None, bool]: + """Given a constructor for observer or fake quant module, returns + a Tuple of dtype and is_dynamic + """ + # TODO: instead of instantiating the instance, we can use inspect to get the default args + if obs_or_fq is None: + return None, False + else: + return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value] + + +def _is_input_arg_dtype_supported_by_backend( + arg: Argument, + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, + backend_config: BackendConfig, +) -> bool: + """Check if the configured qconfig for the argument + is supported by the backend or not + """ + if isinstance(arg, (list, tuple)): + return all( + _is_input_arg_dtype_supported_by_backend( + a, node, qconfig, dtype_config, backend_config + ) + for a in arg + ) + if not isinstance(arg, Node): + return True + # TODO: support check for standalone module + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + if is_activation: + input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr" + ) + input_act_obs_or_fq = ( + input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None + ) + qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic( + input_act_obs_or_fq + ) + # TODO(future PR): remove the cast to bool below after figuring + # out why backend_config has is_dynamic set to None in some cases. + return (dtype_config.input_dtype is None) or ( + dtype_config.input_dtype == qconfig_dtype + and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) + and _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.input_dtype_with_constraints + ) + ) + elif is_weight: + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", None + ) + weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None + qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq) + backend_config_weight_dtype = dtype_config.weight_dtype + dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False + ) + return backend_config_weight_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + else: # bias + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", None + ) + bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None + qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq) + backend_config_bias_dtype = dtype_config.bias_dtype + return ( + backend_config_bias_dtype is None + or qconfig_bias_dtype == backend_config_bias_dtype + ) + + +def _is_output_dtype_supported_by_backend( + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, +) -> bool: + """Check if the configured qconfig for the output + is supported by the backend or not + """ + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + backend_config_output_dtype = dtype_config.output_dtype + # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend + # from input activation check can be reused here + qconfig_output_dtype = None + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic( + output_act_obs_or_fq + ) + # TODO: this is a hack because we can only specify one activation_obs_or_fq for + # qconfig (qconfig.activation), and we are only supporting dynamically quantized + # linear op which has fp32 output dtype, this should be removed if we generalize + # the structure of qconfig in the future + if qconfig_output_is_dynamic: + qconfig_output_dtype = torch.float32 + dtype_matches = qconfig_output_dtype == backend_config_output_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.output_dtype_with_constraints + ) + return backend_config_output_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + + +def _is_observer_in_same_graph( + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat, +): + """Check if observer in same graph + when the node output is not fp32 and input is 'placeholder' + the input is assumed to be quantized, so it is observed + in a different place rather than not observed. + """ + node_output_dtype = _get_arg_target_dtype_as_output( + node, named_modules, obs_or_fq_map, is_qat + ) + if len(node.args) > 0 and isinstance(node.args[0], Node): + if ( + node_output_dtype in [torch.quint8, torch.uint8] + and node.args[0].op == "placeholder" + ): + return False + return True + + +def _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern: Pattern | None, + matched_node_pattern: list[Node] | None, + qconfig: QConfigAny, + backend_config: BackendConfig, +) -> bool: + """Check if the dtype configuration of a pattern is supported by + the backend or not, and whether the qconfig satisfies constraints + specified in the corresponding dtype config. + """ + if backend_config is None or pattern is None: + return True + if matched_node_pattern is None or len(matched_node_pattern) < 1: + raise AssertionError("matched_node_pattern must be non-empty") + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + input_node = root_node + output_node = matched_node_pattern[0] + for dtype_config in dtype_configs: + # check if arg dtype are supported + supported = True + for arg in list(input_node.args) + list(input_node.kwargs.values()): + supported = supported and _is_input_arg_dtype_supported_by_backend( + arg, input_node, qconfig, dtype_config, backend_config + ) + # check if output dtype is supported + supported = supported and _is_output_dtype_supported_by_backend( + output_node, qconfig, dtype_config + ) + if supported: + return True + return False + + +def _get_standalone_module_configs( + node: Node, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, + parent_qconfig: QConfigAny, + parent_backend_config: BackendConfig | None, +) -> tuple[QConfigMapping, tuple[Any, ...], PrepareCustomConfig, BackendConfig | None]: + """ + Returns the standalone module QConfigMapping and PrepareCustomConfig + for `node`, assuming that the module pointed to by `node` is + a standalone modules. + """ + module_name = str(node.target) + module_type = type(named_modules[module_name]) # type: ignore[index] + # name config has precedence over type config + config_entry = StandaloneModuleConfigEntry(None, (), None, None) + config_entry = prepare_custom_config.standalone_module_classes.get( + module_type, config_entry + ) + config_entry = prepare_custom_config.standalone_module_names.get( + module_name, config_entry + ) + # fallback to use parent module's qconfig if user didn't specify qconfig dict + qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global( + parent_qconfig + ) + example_inputs = config_entry.example_inputs + prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig() + backend_config = config_entry.backend_config or parent_backend_config + return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + + +def _qat_swap_modules( + root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]] +) -> None: + convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False) + + +def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]): + if isinstance(matched_node_pattern, Node): + s.add(matched_node_pattern.name) + elif isinstance(matched_node_pattern, (list, tuple)): + for maybe_node in matched_node_pattern: + _add_matched_node_name_to_set(maybe_node, s) + + +def _insert_obs_or_fq( + node: Node, + obs_or_fq: ObserverOrFakeQuantize, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + model_device: torch.device | None = None, +) -> Node: + """ + Attaches `obs_or_fq` to `model`, and creates a node which calls + `obs_or_fq` on the output of `node`. + + obs_or_fq: an instance of Observer or FakeQuantize module + """ + if model_device is None: + model_device = assert_and_get_unique_device(model) + if model_device: + obs_or_fq.to(model_device) + # add obs_or_fq module as attribute + if is_equalization_observer(obs_or_fq): + prefix = node.name + "_equalization_process_" + else: + prefix = "activation_post_process_" + get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix) + obs_or_fq_name = get_new_obs_or_fq_name(model) + setattr(model, obs_or_fq_name, obs_or_fq) + named_modules[obs_or_fq_name] = obs_or_fq + with graph.inserting_after(node): + new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {}) + return new_obs + + +def _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern: NodePattern, + last_node: Node, + qconfig: QConfigAny, + qhandler: QuantizeHandler | None, + backend_config: BackendConfig, + named_modules: dict[str, torch.nn.Module], + cache_for_no_tensor_check: dict[Node, bool], + processed_nodes: set[Node], +) -> None: + """Sets the target_dtype_info for each node in matched_node_pattern + Note: processed_nodes is used to ensure we only process each node once + """ + if isinstance(matched_node_pattern, (list, tuple)): + for node_pattern in matched_node_pattern: + _set_target_dtype_info_for_matched_node_pattern( + node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # set target_dtype_info if matched_node_pattern is a Node + # other types of matched object, e.g. int, float literals, are ignored + elif isinstance(matched_node_pattern, Node): + # for pyre + if not isinstance(matched_node_pattern, Node): + raise AssertionError("matched_node_pattern must be a Node") + node = matched_node_pattern + if node in processed_nodes: + return + processed_nodes.add(node) + + if qconfig is None: + return + # TODO: refactor the following code in terms of apply a qconfig to a pattern + # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1) + # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act, + # and set output_obs_or_fq_ctr based on qconfig.output_act + # this also requires we extend the structure of QConfig to support more fine + # grained configurations + target_dtype_info: dict[str, Any] = _get_target_activation_dtype_for_node( + node, + qconfig, + qhandler, + named_modules, + backend_config, + cache_for_no_tensor_check, + ) + node.meta["target_dtype_info"] = target_dtype_info + + +def _get_target_activation_dtype_for_node( + node: Node, + qconfig: QConfigAny, + qhandler: QuantizeHandler | None, + named_modules: dict[str, torch.nn.Module], + backend_config: BackendConfig, + cache_for_no_tensor_check: dict[Node, bool], +) -> dict[str, Any]: + """ + For each op attribute in the op's input activation, output activation, + weight, bias - returns the settings of dtype and is_dynamic we expect + for the `quantize` call in the reference model representation, or None + if there is no `quantize` call needed. + + For example, if we have a node corresponding to `op0` in + + x0 -> op0 -> x1 + + And we want a reference quantized representation to be + + x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1 + + Then this function will return + + { + "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + } + + TODO(future PR, if needed): explicitly spell out the non-Tensor + dtypes. + """ + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + return { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + # get qconfig to determine the eventual dtype of this node + if qconfig is not None: + act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig) + + # Currently `QConfig` only has one `activation` field. + # For static quantization, it is reused for both input + # and output activation. For dynamic quantization, this + # field is currently only used for the input activation, + # with the output activation being in fp32. + # In the future this may change as we add more fields + # to the `QConfig` object. + bias_dtype = ( + torch.float16 + if ( + act_dtype == torch.float16 + and weight_dtype == torch.float16 + and (not input_act_is_dynamic) + ) + else torch.float + ) + + is_general_tensor_value_op = ( + qhandler is not None and qhandler.is_general_tensor_value_op() + ) + + _is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + + weight_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + weight_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("weight") + + bias_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + bias_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("bias") + + return { + "input_act_obs_or_fq_ctr": qconfig.activation, + "weight_obs_or_fq_ctr": qconfig.weight, + "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype), + "weight_index": weight_index, + "bias_index": bias_index, + "output_act_obs_or_fq_ctr": qconfig.activation, + "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig), + "input_output_share_observers": is_general_tensor_value_op, + "_is_standalone_module": _is_standalone_module, + } + return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) + + +def _get_output_act_obs_or_fq( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> ObserverOrFakeQuantize | None: + """Get the constructor for observer or fake quant object for + the argument in the original graph as the output of previous node, + skipping inserted observers + + We are assuming that the observers are inserted correctly, and the dtype for + argument in quantized graph will match what is specified by the qconfig + """ + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + if "quantization_annotation" in arg.meta: + return _create_obs_or_fq_from_qspec( + arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + + # Custom module LSTM output is a tuple that we broke down into the internal nodes in order + # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + # Since we modified the graph in this case, we must trace back from the args through + # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would + # not be able to accurately detect whether this node is a consumer of custom module LSTM. + custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg( + arg, named_modules + ) + output_act_obs_or_fq_ctr = None + if custom_module_lstm_node is not None: + output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + elif _is_activation_post_process_node(arg, named_modules): + observed_arg = arg.args[0] + if not isinstance(observed_arg, Node): + raise AssertionError("Currently we only support observing Node") + if "quantization_annotation" in observed_arg.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + observed_arg.meta["quantization_annotation"].output_qspec, + obs_or_fq_map, + is_qat, + ) + else: + if "target_dtype_info" not in observed_arg.meta: + raise AssertionError( + "expected 'target_dtype_info' in observed_arg.meta" + ) + output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + else: + if "target_dtype_info" in arg.meta: + output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + + return output_act_obs_or_fq + + +def _get_arg_target_dtype_as_output( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> torch.dtype | None: + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic( + arg_as_output_act_obs_or_fq + ) + return arg_as_output_target_dtype + + +def _get_arg_as_input_act_obs_or_fq( + arg: Node, + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> ObserverOrFakeQuantize | None: + """Get the observer or fake quant constructor for the Argument `arg`, as input + to Node `node` + """ + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + # "input_qspec_map" is the more general design we'll use for pt2e path + # it is a map from input argument node to observer or fake quant constructor, for example + # for the following graph: + # x -> conv -> output + # + # we may annotate conv node like the following: + # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...) + # + if "quantization_annotation" in node.meta: + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules) + if input_arg_qspec is None: + input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR() + else: + input_arg_obs_or_fq = _create_obs_or_fq_from_qspec( + input_arg_qspec, obs_or_fq_map, is_qat + ) + return input_arg_obs_or_fq + + # we can remove the following path in the future if fx graph mode quantization is + # no longer used + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + obs_or_fq_ctr = None + if is_activation: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + elif is_weight: + if node.target not in NON_QUANTIZABLE_WEIGHT_OPS: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + return obs_or_fq_ctr() if obs_or_fq_ctr else None + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Node | Any, + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: QuantizeHandler | None, + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: BackendConfig | None = None, + model_device: torch.device | None = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + # default (no observer) + new_arg = arg + + is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + # TODO: move this to a separate function + if not is_standalone_module: + # Note: qconfig can be None in this branch this we are getting act/fq from + # node.meta now + # regular flow for most nodes, except standalone modules + + if "quantization_annotation" in node.meta: + reuse_input_obs_or_fq = node.meta[ + "quantization_annotation" + ]._reuse_input_obs_or_fq + else: + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") + # TODO: we are assuming "target_dtype_info" exists here, maybe + # a default value also need to be provided here + target_dtype_info = node.meta["target_dtype_info"] + # for nodes that doesn't have `reuse_input_obs_or_fq` configured, + # we'll default to False, this makes configuring this field optional for users + reuse_input_obs_or_fq = target_dtype_info.get( + "reuse_input_obs_or_fq", False + ) + arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq( + arg, node, named_modules, obs_or_fq_map, is_qat + ) + ( + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) + + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + ( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + + needs_obs_or_fq = _needs_obs_or_fq( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + reuse_input_obs_or_fq, + is_zeroth_arg=len(node.args) > 0 and arg is node.args[0], + ) + + else: + if qconfig is None: + raise AssertionError("qconfig must not be None") + # custom flow for standalone modules + _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( + node, named_modules, prepare_custom_config, qconfig, backend_config + ) + sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes + + # for args, this is set to the index of the current arg + # for kwargs, this is left at None + cur_input_idx = None + for arg_idx, arg_to_check in enumerate(node.args): + if arg_to_check is arg: + cur_input_idx = arg_idx + break + + if cur_input_idx is None: + needs_obs_or_fq = False + else: + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_input_target_dtype = ( + torch.quint8 + if cur_input_idx in sm_input_quantized_idxs + else torch.float + ) + needs_obs_or_fq = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + ) and (arg_as_input_target_dtype != torch.float) + + act_post_process_ctr = qconfig.activation + arg_as_input_act_obs_or_fq = ( + act_post_process_ctr() if act_post_process_ctr else None + ) + + if needs_obs_or_fq: + existing_obs_node = None + + # Before using the new observer, check if an observer + # of the correct type already exists. If it does, use it. + # This prevents duplicate observer insertions if a node is + # used by multiple nodes. + # TODO: this is looking into how the value is used in the future + # we should remove this + # removing this means we insert one observer for each use, even if they + # have the same dtype, we can have an extra pass that removes the extra observers + for maybe_obs_node in arg.users: + if maybe_obs_node.op == "call_module": + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if ( + type(maybe_obs_mod) is type(arg_as_input_act_obs_or_fq) + and maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined] + ): + arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] + existing_obs_node = maybe_obs_node + break + + if arg_as_input_act_obs_or_fq is None: + raise AssertionError("arg_as_input_act_obs_or_fq must not be None") + obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq + if existing_obs_node is None: + new_obs_node = _insert_obs_or_fq( + arg, + arg_as_input_act_obs_or_fq, + model, + named_modules, + graph, + model_device, + ) + # override this arg to be the observed arg + new_arg = new_obs_node + else: + new_arg = existing_obs_node + + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: QuantizeHandler | None, + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: BackendConfig | None = None, + model_device: torch.device | None = None, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + Note: backend_config only needed for standalone_module node + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_args.append(new_arg) + + new_kwargs = {} + for k, kwarg in node.kwargs.items(): + new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + kwarg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_kwargs[k] = new_kwarg + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + node.kwargs = new_kwargs + + +def _maybe_insert_input_equalization_observers_for_node( + node: Node, + equalization_qconfig: Any, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + is_branch: bool, +) -> None: + """ + If `node` needs to be equalized, find the input/weight observers it needs in + `equalization_qconfig`, creates them, and inserts it into `graph`. + + If `node` does not need an equalization observer, returns None. + """ + if equalization_qconfig is None or not node_supports_equalization( + node, named_modules + ): + return + + if is_branch: + warnings.warn( + f"Cannot equalize {node} because it is part of a branch.", stacklevel=2 + ) + return + + new_args = [] + for arg in node.args: + if not isinstance(arg, Node) or node_arg_is_bias(node, arg): + new_args.append(arg) + continue + + is_weight = node_arg_is_weight(node, arg) + + act_eq_process_ctr = ( + equalization_qconfig.weight + if is_weight + else equalization_qconfig.input_activation + ) + + new_eq_obs_mod = act_eq_process_ctr() + new_eq_obs_node = _insert_obs_or_fq( + arg, new_eq_obs_mod, model, named_modules, graph + ) + + new_args.append(new_eq_obs_node) + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Node | None: + """ + If `node` needs an output observer, creates it, inserts it into `graph` + and returns it. + + If `node` does not need an output observer, returns None. + + Note: inserting dynamic quantization ops for output is not supported in fx graph mode + quantization code path right now + """ + if node.op == "output": + raise AssertionError("observer insertion for outputs is handled elsewhere") + + is_standalone_module = False + if "quantization_annotation" in node.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + else: + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") + is_standalone_module = node.meta["target_dtype_info"].get( + "_is_standalone_module", False + ) + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr" + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) + # uncomment after we support reuse_input_obs_or_fq properly by having separate + # implementations for this key instead of reusing the input_output_share_observers + # code + # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) + # for now we set this to False since reuse_input_obs_or_fq for + # the output of a node is implementation in the same code path as observer sharing, + # we should refactor this part to make it clearer in the future + # and we would be able to read this from config directly + reuse_input_obs_or_fq = False + + # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False + # because the prev_output is the output of an fp32 op, although technically + # we should get the dtype of the output from node.meta["val"] in the future + # if we deprecate fx graph mode quantization + needs_obs_or_fq = _needs_obs_or_fq( + torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq + ) + # currently the activation in QConfig(activation=...,) is for both input + # and output, and when the activation is configured to be dynamic quantization + # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means + # the input should by dynamically quantized, but output should not be quantized + # + # there is no way we can specify different observer/fq for input and output + # activation through QConfig today, this limitation is lifted in the + # quantizer/annotation API in pytorch 2.0 export quantization code path, + # but since this code is reused, annotating output to be dynamically quantized + # would not work either for that. + # we can change QConfig to support input/output activation if we want + # to remove the following check, or if we can deprecate fx graph mode quantization + if target_is_dynamic: + needs_obs_or_fq = False + + # we never insert observers to output of standalone module, we assume + # if needed, they are inserted inside the standalone module + needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module) + + if needs_obs_or_fq: + obs_or_fq_map[node] = output_act_obs_or_fq + return _insert_obs_or_fq( + node, output_act_obs_or_fq, model, named_modules, graph + ) + else: + return None + + +def _maybe_insert_observers_before_graph_output( + graph_output_node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If the output needs to be quantized and there are any nodes + in the output which are not already observed, inserts observers + for those nodes. + """ + + def _recursive_maybe_replace_node_with_obs( + maybe_node: Argument, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + ) -> Argument: + """ + Navigate an arbitrary data structure of lists, tuples, dicts. + For each container type, recurse on all inputs. Once any Node + is found, insert an observer if needed and do not recurse further. + + For example, given a structure of + + {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} + + we recurse down to bar1 and bar3, observe them if necessary, + and if we inserted an observer then replace the original node + with its observer. + + Returns the data structure with all nodes needing observation being + replaced by their observers. + """ + if isinstance(maybe_node, Node): + # check dtype of this node + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + maybe_node, named_modules, obs_or_fq_map, is_qat + ) + observer_mod = None + arg_as_input_target_dtype = torch.float + if "target_dtype_info" in maybe_node.meta: + observer_cls = maybe_node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", None + ) + if observer_cls is not None: + observer_mod = observer_cls() + arg_as_input_target_dtype = observer_mod.dtype + # TODO: this does not handle dynamic quantization yet + need_obs = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + and arg_as_input_target_dtype != torch.float + ) + if need_obs: + if observer_mod is None: + raise AssertionError( + "observer_mod must not be None when need_obs is True" + ) + # insert observer + observer_node = _insert_obs_or_fq( + maybe_node, observer_mod, model, named_modules, graph + ) + return observer_node + else: + return maybe_node + elif isinstance(maybe_node, (list, tuple)): + results = [ + _recursive_maybe_replace_node_with_obs( + inner_node, model, named_modules, graph + ) + for inner_node in maybe_node + ] + if isinstance(maybe_node, list): + return results + else: + return tuple(results) + elif isinstance(maybe_node, dict): + results_dict = {} + for k, inner_v in maybe_node.items(): + results_dict[k] = _recursive_maybe_replace_node_with_obs( + inner_v, model, named_modules, graph + ) + return results_dict + elif maybe_node is None: + return None + else: + raise Exception( # noqa: TRY002 + "Unhandled type for returned node:", maybe_node + ) + + new_args = [ + _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph) + for old_arg in graph_output_node.args + ] + + graph_output_node.args = tuple(new_args) # type: ignore[assignment] + + +def _maybe_propagate_dtype_for_node( + node: Node, + target_dtype: torch.dtype | type, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node` + is a general tensor shape op, also call this function recursively on + the first argument, to propagate the dtype to the caller. + """ + node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None + node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None + # if this is a copy node, propagate to first arg + ( + _root_node, + _, + _pattern, + qhandler, + _qconfig, + ) = node_name_to_match_result_with_qconfig.get( + node.name, (None, None, None, None, None) + ) + # TODO: probably need to remove `is_general_tensor_value_op` + if qhandler is not None and qhandler.is_general_tensor_value_op(): + prev_node = node.args[0] + if isinstance(prev_node, Node): + _maybe_propagate_dtype_for_node( + prev_node, target_dtype, node_name_to_match_result_with_qconfig + ) + + +def propagate_dtypes_for_known_nodes( + graph: Graph, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Currently we assume that inputs to the graph are either `torch.float` or + `torch.quint8`, which is not always correct. For ops such as + `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a + `BoolTensor`. Propagate this information throughout the graph. + + Note: not all dtypes in the graph will be correct after this pass, but a + higher percentage of them will be correct. Hopefully in the future we can + replace this with a better way to reason about dtypes of tensors. + """ + for node in graph.nodes: + non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node) + + for arg_type in non_observable_arg_dict: + non_observable_indices = non_observable_arg_dict[arg_type](node) + + for index in non_observable_indices: + arg = node.args[index] + + # when an argument is a tuple, it does not show up as another node so we need to go through + # all elements of the tuple manually + if isinstance(arg, (tuple, list)): + arg_list = list(arg) + else: + arg_list = [arg] + + for cur_arg in arg_list: + # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated + if isinstance(cur_arg, torch.fx.node.Node): + _maybe_propagate_dtype_for_node( + cur_arg, arg_type, node_name_to_match_result_with_qconfig + ) + + +def _maybe_make_input_output_share_observers( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], +) -> bool: + """ + Ensures that we share an observer + for all input arguments as well as the output argument. In detail, given + a graph of + + x0 -> obs0 -> op -> x2 + / + x1 -> obs1 / + + where node obs0 points to observer instance observer0, + obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 + and ob2 point to observer0. + Returns: whether the operation succeeded or not + """ + first_arg = None + # find the first non-Tensor arg + for i in range(len(node.args)): + if isinstance(node.args[i], (Node, list, tuple)): + first_arg = node.args[i] + break + + # if there is no non-Tensor arg, return directly + if first_arg is None: + return False + + if isinstance(first_arg, (list, tuple)): + first_arg_arg = first_arg[0] + elif isinstance(first_arg, Node): + first_arg_arg = first_arg + else: + return False + + # if we have a graph such as + # observed_node -> non_observed_node -> cat + # we need to navigate up to the first observer + iteration_guard = 0 + while not _is_activation_post_process_node(first_arg_arg, named_modules): + if not isinstance(first_arg_arg, Node): + return False + # did not find an activation_post_process for the op + if first_arg_arg.op == "placeholder": + return False + # trace back the args until we found the first Tensor/Node + trace_back_node = None + for i in range(len(first_arg_arg.args)): + trace_back_node = first_arg_arg.args[i] + if isinstance(trace_back_node, Node): + break + if trace_back_node is None: + return False + first_arg_arg = trace_back_node + + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + if not isinstance(first_arg_arg, Node): + raise AssertionError("first_arg_arg must be a Node") + target_to_use = first_arg_arg.target + if not isinstance(target_to_use, str): + raise AssertionError("target_to_use must be a string") + obs_mod_to_use = named_modules[target_to_use] + + if isinstance(first_arg, (list, tuple)): + # set all other input observer nodes to use that module + for input_idx, input_arg in enumerate(first_arg): + if input_idx == 0: + continue + iteration_guard = 0 + while not _is_activation_post_process_node(input_arg, named_modules): + # failed to trace back since no input arg for the current node + if len(input_arg.args) < 1: + return False + input_arg = input_arg.args[0] + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + parent_name, name = _parent_name(input_arg.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # set the output observer node to use that module + for output_obs_node in node.users: + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) + parent_name, name = _parent_name(output_obs_node.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # TODO(future PR): delete the orphaned observer modules + return True + + +def _remove_output_observer( + node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module] +): + items = list(node.users.items()) + for output_obs_node, _ in items: + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) + output_obs_node.replace_all_uses_with(node) + model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] + + +def _swap_custom_module_to_observed( + node: Node, + qconfig: QConfigAny, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, +): + custom_module = named_modules[node.target] # type: ignore[index] + custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping + observed_custom_module_class = get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig + ) + observed_custom_module = observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(named_modules[parent_name], name, observed_custom_module) + + +def insert_observers_for_model( + model: GraphModule, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], + node_name_to_qconfig: dict[str, QConfigAny], + prepare_custom_config: PrepareCustomConfig, + equalization_config_map: dict[str, Any], + backend_config: BackendConfig, + observed_node_names: set[str], + is_qat: bool, +) -> Node | None: + """ + Inserts observers, using the following high level algorithm: + + For each node in the graph: + 1. determine the target dtype of this node in the quantized graph, and save + it for future steps + 2. determine the target dtype or all args and kwargs of this node + 3. if any arg or kwarg's target dtype does not match the current node's + dtype, insert an observer + 4. if the current node needs an output observer, insert it + + For example: + + - starting graph: + x0 -> linear -> x1 + + - observed graph after processing x0: + x0(fp32) + + - observed graph after processing linear: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) + + - observed graph after processing x1: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 + + After a node is processed, the naive observer placement is guaranteed to be + complete for that node and all of its predecessors. There can be future + passes which optimize the graph by deduplicating observers, etc. + """ + + # node.meta["target_dtype_info"] stores the target dtype information + # that's derived from qconfig for the Node, for example, if we have + # a conv2d node that has a qconfig + # qconfig = QConfig(activation=..., weight=...) + # # information for input and bias node omitted + # # for getattr node + # # weight = getattr(self, 'weight') + # weight.meta["target_dtype_info"] = { + # 'output_act_obs_or_fq_ctr': qconfig.weight, + # } + # # for conv2d node + # # conv2d = call_function[target=torch.nn.functional.conv2d]( + # # args=(input, weight, bias)) + # conv2d.meta["target_dtype_info"] = { + # 'input_act_obs_or_fq_ctr': qconfig.activation + # 'weight_obs_or_fq_ctr': qconfig.weight, + # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32), + # 'output_act_obs_or_fq_ctr': qconfig.activation, + # } + # + cache_for_no_tensor_check: dict[Node, bool] = {} + + # first, populate the dtype map based only on qconfig and qhandler + # this assumes: + # graph inputs are fp32 by default, and int8 where overridden + # other nodes output dtype is specified by the qconfig + named_modules = dict(model.named_modules(remove_duplicate=False)) + + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + processed_nodes: set[Node] = set() + # initialize target_dtype_info + for node in model.graph.nodes: + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + inputs_seen_counter = 0 + outputs_seen_counter = 0 + placeholder_node_to_input_index: dict[Node, int] = {} + # TODO: we probably don't need this counter since each graph will only have + # one output node? + output_node_to_output_index: dict[Node, int] = {} + for node in model.graph.nodes: + if node.op == "placeholder": + placeholder_node_to_input_index[node] = inputs_seen_counter + inputs_seen_counter += 1 + if node.op == "output": + output_node_to_output_index[node] = outputs_seen_counter + outputs_seen_counter += 1 + + # Step 1, set the observer or fake quantize module constructor for each node in the + # matched_node_pattern + + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + if qhandler is None: + raise AssertionError("qhandler must not be None") + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # Step 2. Special cases for some operators, we might be able to remove them + # in the future if we know dtype information of each node better + + # Step 2.1. some settings are not based on patterns, we need to process each node + # instead + for node in model.graph.nodes: + if ( + node.op == "placeholder" + and placeholder_node_to_input_index[node] in input_quantized_idxs + ): + # users are not supposed to call calculate_qparams on PlaceholderObserver, and + # this is OK because we are using this as a way to encode the dtypes of input + # tensor, we won't actually insert these observers in the graph and won't + # actually call calculate_qparams + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + elif node.op in ("call_module", "call_method", "call_function"): + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + node.meta["target_dtype_info"] = { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + elif ( + node.op == "output" + and output_node_to_output_index[node] in output_quantized_idxs + ): + # TODO(future PR): update the output_quantized_idxs API to match + # arbitrary data structures. There is always a single output, and + # that output can have arbitrary nesting of values. List[int] is + # not the right data type for this. + + # TODO(future PR): support more dtypes in model outputs, if necessary + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + # Step 2.2, for nodes with known input dtypes, propagate them throughout the + # graph. For example, if there is a call such as + # x1 = x0.masked_fill(mask, 1) + # we propagate the type of mask to be torch.bool + propagate_dtypes_for_known_nodes( + model.graph, node_name_to_match_result_with_qconfig + ) + + # Step 3, check if the requested target_dtype_info is supported by backend or not + # if not, we'll reset the target_dtye_info to use the default (float Tensor) + + # reset the counters and set of processed_nodes + processed_nodes: set[Node] = set() + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + if qhandler is None: + raise AssertionError("qhandler must not be None") + + # get output_act_dtype so that we don't also reset the special typed nodes + # TODO: we might want to handle these more uniformly with the default path + # this can be improved if we can use node.meta["val"] + output_act_or_fq_ctr = node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None + output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq) + if not is_supported_by_backend and output_act_dtype not in [ + None, + int, + float, + torch.bool, + ]: + # restore target_dtype_info to default if it is not supported by backend + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig, + None, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # After this point, the current node and all of its arguments + # have a target_dtype_info assigned. Now, we insert observers for inputs + # of this node (if needed for this node), and the output of this node + # (if needed for this node). + + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # Avoid duplicates custom module swaps for multiple nodes with same target. + custom_module_names_already_swapped: set[str] = set() + + # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index + # reset inputs/outputs counters + inputs_seen_counter = 0 + outputs_seen_counter = 0 + results_node = None + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + model_device = assert_and_get_unique_device(model) + + # TODO: change this to insert obs/fq by pattern instead of by node + for node in nodes_before_observation: + if node.op == "placeholder": + # if a graph input is in fp32, it does not need observation + # if a graph input is in int8, we assume the observation happens + # outside of the graph, and no additional observation is needed + pass + + elif node.op in ("call_module", "call_method", "call_function", "output"): + # check for matches + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment] + node.name, (None, None, None, None, None) + ) + equalization_qconfig = equalization_config_map.get(node.name, None) + + this_node_dtype_info = node.meta["target_dtype_info"] + if "val" in node.meta: + output_is_a_tensor = this_node_dtype_info is not None and isinstance( + node.meta["val"], FakeTensor + ) + else: + output_is_a_tensor = this_node_dtype_info is not None + + skip_inserting_observers = ( + (qconfig is None) or not output_is_a_tensor + ) and (node.op != "output") + + # TODO: take a closer look to see if we can remove this check + # right now it is here because of `observed_node_names`, we are using + # it as an indicator for swapping the modules to reference modules in + # convert + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + + if not skip_inserting_observers and is_supported_by_backend: + named_modules = dict(model.named_modules(remove_duplicate=False)) + if node.op != "output": + if matched_node_pattern is None: + raise AssertionError("matched_node_pattern must not be None") + # add matched nodes to the observed node name set + _add_matched_node_name_to_set( + matched_node_pattern, observed_node_names + ) + + # This is currently only used for equalization. + # Checks if the current node is in a branch in which the two + # first layers are both being quantized. + # + # ex. conv2 + # / + # x -> conv1 + # + # If this is the case, we will not apply equalization to the + # initial two layers. + is_quantized_branch = False + if ( + len(node.args) > 0 + and isinstance(node.args[0], Node) + and len(node.args[0].users) > 1 + ): + for user in node.args[0].users: + # Checks if there exists another user being quantized + is_user_quantized = node_name_to_qconfig.get( + user.name, None + ) is not None or ( + user.op == "call_module" + and isinstance( + named_modules[str(user.target)], ObserverBase + ) + ) + if user != node and is_user_quantized: + is_quantized_branch = True + + pattern_to_root_node_getter = ( + get_fusion_pattern_to_root_node_getter(backend_config) + ) + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + is_input_node_of_the_pattern = node is root_node + if is_input_node_of_the_pattern: + # this modifies node inplace + _maybe_insert_input_observers_for_node( + node, + qconfig, + model, + named_modules, + model.graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + + # insert equalization input observers if needed + _maybe_insert_input_equalization_observers_for_node( + node, + equalization_qconfig, + model, + named_modules, + model.graph, + is_quantized_branch, + ) + + is_last_node_of_pattern = node is last_node + input_output_share_observers = node.meta["target_dtype_info"].get( + "input_output_share_observers", False + ) + reuse_input_obs_or_fq = node.meta["target_dtype_info"].get( + "reuse_input_obs_or_fq", False + ) + + if is_last_node_of_pattern: + if _is_custom_module_lstm( + node, named_modules, qconfig, qhandler + ): + # Currently custom module outputs are assumed to be already quantized, + # so we need to insert a DeQuantStub after the output. For custom module + # LSTM specifically, the outputs are also a nested tuple, so we must first + # break down the tuple to insert DeQuantStubs after the internal nodes. + + # TODO: This currently diverges from how custom modules are handled today, + # where we insert observers after the output instead of DeQuantStubs, and + # replace these observers with "dequantize" nodes during convert. Conceptually, + # these output observers are the same as DeQuantStubs. In the future, we + # should resolve this inconsistency by inserting DeQuantStubs for all custom + # modules, not just for LSTM. + _insert_dequant_stubs_for_custom_module_lstm_output( + node, model, named_modules, model.graph + ) + if node.target not in custom_module_names_already_swapped: + custom_module_names_already_swapped.add(node.target) + _swap_custom_module_to_observed( + node, qconfig, named_modules, prepare_custom_config + ) + else: + # this returns the new observer node if it was needed + maybe_output_obs_node = ( + _maybe_insert_output_observer_for_node( + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + ) + ) + + if maybe_output_obs_node is not None: + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with( + node, maybe_output_obs_node + ) + + _is_observer_in_same_graph_ = ( + _is_observer_in_same_graph( + node, named_modules, obs_or_fq_map, is_qat + ) + ) + + # for ops whose inputs and outputs share observer/fqs, we modify the graph + # to make all inputs and outputs use the first input's + # observer/fq + if ( + input_output_share_observers + and _is_observer_in_same_graph_ + ) or reuse_input_obs_or_fq: + if not _maybe_make_input_output_share_observers( + node, model, named_modules + ): + _remove_output_observer( + node, model, named_modules + ) + + if qhandler is not None and qhandler.is_custom_module(): + if ( + node.target + not in custom_module_names_already_swapped + ): + custom_module_names_already_swapped.add( + node.target + ) + _swap_custom_module_to_observed( + node, + qconfig, + named_modules, + prepare_custom_config, + ) + + else: # output + _maybe_insert_observers_before_graph_output( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat + ) + + # + # After this point, the current node has input and output observers + # that it needs for itself inserted. + # + + # increment the counters, so future inputs and outputs are assigned + # correct dtypes + if node.op == "placeholder": + inputs_seen_counter += 1 + elif node.op == "output": + outputs_seen_counter += 1 + results_node = node + + return results_node + + +def _run_prepare_fx_on_standalone_modules( + model: torch.nn.Module, + is_qat: bool, + named_modules: dict[str, torch.nn.Module], + node_name_to_match_result_with_qconfig: Any, + prepare_custom_config: PrepareCustomConfig, + backend_config: BackendConfig, +) -> None: + """ + Runs prepare_fx on each standalone module. Note: this does + not modify the graph, it just replaces the unobserved modules with + their observed versions. + """ + for ( + root_node, + _, + _pattern, + qhandler, + qconfig, + ) in node_name_to_match_result_with_qconfig.values(): + if qhandler is None: + continue + elif not qhandler.is_standalone_module(): + continue + + ( + sm_qconfig_mapping, + sm_example_inputs, + sm_prepare_custom_config, + sm_backend_config, + ) = _get_standalone_module_configs( + root_node, named_modules, prepare_custom_config, qconfig, backend_config + ) + + standalone_module = named_modules[root_node.target] + prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] + observed_standalone_module = prepare( + standalone_module, + sm_qconfig_mapping, + is_qat, + example_inputs=sm_example_inputs, + prepare_custom_config=sm_prepare_custom_config, + backend_config=sm_backend_config, + ) + parent_name, name = _parent_name(root_node.target) + setattr(named_modules[parent_name], name, observed_standalone_module) + named_modules[root_node.target] = observed_standalone_module + + +def _save_state( + observed: GraphModule, + node_name_to_qconfig: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + prepare_custom_config: PrepareCustomConfig, + equalization_node_name_to_qconfig: dict[str, Any], + qconfig_mapping: QConfigMapping, + is_qat: bool, + observed_node_names: set[str], +) -> None: + observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs( + node_name_to_qconfig=node_name_to_qconfig, + node_name_to_scope=node_name_to_scope, + prepare_custom_config=prepare_custom_config, + equalization_node_name_to_qconfig=equalization_node_name_to_qconfig, + qconfig_mapping=qconfig_mapping, + is_qat=is_qat, + observed_node_names=observed_node_names, + ) + + +def prepare( + model: GraphModule, + qconfig_mapping: QConfigMapping | dict[str, Any], + is_qat: bool, + node_name_to_scope: dict[str, tuple[str, type]], + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + _equalization_config: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, +) -> GraphModule: + """standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + Args: + node_name_to_scope: mapping from node name to the scope of the module which contains the node. + The scope is a tuple of fully qualified path of the module and the type of the module + Returns: + model(GraphModule): prepared standalone module + attributes related to standalone module + in model.meta["_observed_graph_module_attrs"]: + is_observed_standalone_module (bool): boolean value that shows whether the + current model is a observed standalone module or not + standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + standalone_module_output_quantized_idxs(List[Int]): a list of + indices for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) + + if isinstance(_equalization_config, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " + "be supported in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + _equalization_config = QConfigMapping.from_dict(_equalization_config) + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if not isinstance(qconfig_mapping, QConfigMapping): + raise AssertionError("qconfig_mapping must be a QConfigMapping") + if not isinstance(_equalization_config, QConfigMapping): + raise AssertionError("_equalization_config must be a QConfigMapping") + qconfig_mapping = copy.deepcopy(qconfig_mapping) + _equalization_config = copy.deepcopy(_equalization_config) + + # mapping from a tuple of nodes in reverse order to uninitialized + # QuantizeHandler subclass. For example, + # { + # # match a single node + # (: + # ), + # # match multiple nodes in reverse order + # ((, ): + # ), + # } + + pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = {} + if backend_config is None: + backend_config = get_native_backend_config() + pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) + pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler) + + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_fusion(model, qconfig_mapping) + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_fusion(model, _equalization_config) + # pyrefly: ignore [bad-argument-type] + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) + # TODO: support regex as well + propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) + + if is_qat: + module_to_qat_module = get_module_to_qat_module(backend_config) + _qat_swap_modules(model, module_to_qat_module) + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_qat(qconfig_mapping, backend_config) + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + named_modules = dict(model.named_modules(remove_duplicate=False)) + + # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches + equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, + named_modules, + model.graph, + # pyrefly: ignore [bad-argument-type] + _equalization_config, + node_name_to_scope, + ) + node_name_to_qconfig = _generate_node_name_to_qconfig( + model, + named_modules, + model.graph, + # pyrefly: ignore [bad-argument-type] + qconfig_mapping, + node_name_to_scope, + ) + + # match the patterns that will get quantized + standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) + standalone_module_classes = list( + prepare_custom_config.standalone_module_classes.keys() + ) + + custom_module_classes = get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + matches_without_qconfig = _find_matches( + model.graph, + named_modules, + pattern_to_quantize_handler, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + + # map qconfig instances to matches + node_name_to_match_result_with_qconfig = {} + for node_name, match_without_qconfig in matches_without_qconfig.items(): + match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) + node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig + + _run_prepare_fx_on_standalone_modules( + model, + is_qat, + named_modules, + node_name_to_match_result_with_qconfig, + prepare_custom_config, + backend_config, + ) + + # record names for the set of observed node, so that in convert step + # we know whether we need to convert a floating point module to reference + # quantized module or not + observed_node_names: set[str] = set() + + result_node = insert_observers_for_model( + model, + node_name_to_match_result_with_qconfig, + node_name_to_qconfig, + prepare_custom_config, + equalization_node_name_to_qconfig, + backend_config, + observed_node_names, + is_qat, + ) + model = GraphModule(model, model.graph) + + _save_state( + model, + node_name_to_qconfig, + node_name_to_scope, + prepare_custom_config, + equalization_node_name_to_qconfig, + # pyrefly: ignore [bad-argument-type] + qconfig_mapping, + is_qat, + observed_node_names, + ) + + if is_standalone_module: + if result_node is None: + raise AssertionError("result_node must not be None for standalone modules") + if not isinstance(result_node.args[0], Node): + raise AssertionError( + "standalone module only supports returning simple value currently (not tuple, dict etc.)" + ) + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = ( + prepare_custom_config.output_quantized_indexes + ) + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + # inplace modification + observed_graph_module_attrs.is_observed_standalone_module = True + observed_graph_module_attrs.standalone_module_input_quantized_idxs = ( + input_quantized_idxs + ) + observed_graph_module_attrs.standalone_module_output_quantized_idxs = ( + output_quantized_idxs + ) + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..783cba8149e6e09164d01c7f9ebafdc2e6240428 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +import re +from collections import defaultdict, OrderedDict +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization import QConfig +from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +from torch.ao.quantization.backend_config.utils import get_module_to_qat_module +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _add_module_to_qconfig_obs_ctr, + qconfig_equals, + QConfigAny, +) +from torch.ao.quantization.qconfig_mapping import ( + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + QConfigMapping, +) +from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__: list[str] = [] + + +def _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping: QConfigMapping, + cur_module_path: str, + cur_object_type: Callable, + cur_object_type_idx: int, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + for ( + module_name, + object_type, + index, + ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items(): + if ( + (module_name == cur_module_path) + and (object_type == cur_object_type) + and (index == cur_object_type_idx) + ): + return qconfig + return fallback_qconfig + + +def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping): + """ + Update the QConfigMapping to account for fused modules such as LinearReLU. + This assumes the QConfigMapping's attributes have already been converted to OrderedDicts. + """ + object_type_dict = qconfig_mapping.object_type_qconfigs + if len(object_type_dict) == 0: + return qconfig_mapping + + modules = dict(model.named_modules()) + + for node in model.graph.nodes: + if node.op == "call_module" and node.target in modules: + maybe_fused_module = modules[str(node.target)] + if not isinstance(maybe_fused_module, _FusedModule): + continue + + ops = list(maybe_fused_module._modules.values()) + fused_qconfig = object_type_dict.get(type(ops[0]), None) + + # Raise an error if the modules in the fused module have + # different qconfigs specified in the qconfig_dict + # TODO: currently it only works for modules, + # need to make this work for torch.nn.functional.relu + # TODO: currently it only works for object_type configurations, + # ideally it should work for different types of configurations, + # maybe we want to redesign this part + for op in ops[1:]: + if not qconfig_equals( + object_type_dict.get(type(op), None), fused_qconfig + ): + raise LookupError( + "During fusion, we need to specify the same " + + f"qconfigs for all module types in {type(maybe_fused_module)} " + + f"offending type: {type(op)}" + ) + + if fused_qconfig is not None: + object_type_dict[type(maybe_fused_module)] = fused_qconfig + + +def _generate_node_name_to_qconfig( + root: torch.nn.Module, + modules: dict[str, torch.nn.Module], + input_graph: Graph, + qconfig_mapping: QConfigMapping, + node_name_to_scope: dict[str, tuple[str, type]], +) -> dict[str, QConfigAny]: + global_qconfig = qconfig_mapping.global_qconfig + node_name_to_qconfig = {} + + # example: + # + # {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...} + # + # meaning in submodule 'foo.bar', we have seen 0 F.linear and + # 1 F.conv2d invocations so far. + submodule_to_object_type_to_cur_idx: dict[str, dict[Callable, int]] = defaultdict( + lambda: defaultdict(int) + ) + for node in input_graph.nodes: + qconfig = None + if node.op == "get_attr": + module_name, _ = _parent_name(node.target) + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[module_name]), module_name, global_qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + elif node.op == "call_function": + # precedence: module_name_qconfig + # > function_qconfig > global_qconfig + # module_name takes precedence over function qconfig + function_qconfig = _get_object_type_qconfig( + qconfig_mapping, node.target, global_qconfig + ) + module_path, module_type = node_name_to_scope[node.name] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, function_qconfig + ) + + cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][ + node.target + ] + submodule_to_object_type_to_cur_idx[module_path][node.target] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # first use node.target (string) to get the qconfig + # this is to support configs like + # "object_type": [("reshape", qconfig)] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, node.target, module_path, global_qconfig + ) + # if there is no special config for the method, we'll fall back to the + # config for the module that contains the call_method node + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, qconfig + ) + # currently call_method does not support modifying qconfig + # by order, we can add this later if it is needed. + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_module": + # if the node is an observer, just continue - don't add it to the qconfig_map + if _is_activation_post_process(modules[node.target]): + continue + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[node.target]), node.target, global_qconfig + ) + + module_path, module_type = node_name_to_scope[node.name] + # Note: for call_module, the module_path is the current module's name. + # to meaningfully count invocations, we need to count them in the parent + # module. + parent_name, _ = _parent_name(module_path) + cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][ + module_type + ] + submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex + # is used + modules[node.target].qconfig = qconfig_with_device_check + else: + qconfig_with_device_check = None + + node_name_to_qconfig[node.name] = qconfig_with_device_check + return node_name_to_qconfig + + +def _check_is_valid_config_dict( + config_dict: Any, allowed_keys: set[str], dict_name: str +) -> None: + r"""Checks if the given config_dict has the correct keys + + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict: + if k not in allowed_keys: + raise ValueError( + "Expected " + + dict_name + + " to have the following keys: " + + str(allowed_keys) + + ". But found '" + + k + + "' instead." + ) + + +def _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping +): + r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values + + Args: + `prepare_qconfig_mapping`: configuration for prepare quantization step + `convert_qconfig_mapping`: configuration for convert quantization step + """ + if not qconfig_equals( + prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig + ): + raise AssertionError( + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ) + prepare_dicts: list[OrderedDict] = [ + prepare_qconfig_mapping.object_type_qconfigs, + prepare_qconfig_mapping.module_name_qconfigs, + prepare_qconfig_mapping.module_name_regex_qconfigs, + ] + convert_dicts: list[OrderedDict] = [ + convert_qconfig_mapping.object_type_qconfigs, + convert_qconfig_mapping.module_name_qconfigs, + convert_qconfig_mapping.module_name_regex_qconfigs, + ] + dict_names = [ + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + ] + for i in range(len(prepare_dicts)): + for name in prepare_dicts[i]: + if name not in convert_dicts[i]: + raise AssertionError( + f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" + ) + if convert_dicts[i][name] is not None and not qconfig_equals( + prepare_dicts[i][name], convert_dicts[i][name] + ): + raise AssertionError( + "Expected convert QConfigMapping to have the same qconfig as prepare for key " + f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + ) + + +def _is_qconfig_supported_by_dtype_configs( + qconfig: QConfig, dtype_configs: list[DTypeConfig] +): + for dtype_config in dtype_configs: + is_dynamic = dtype_config.is_dynamic + if is_dynamic is None: + is_dynamic = False + input_dtype = dtype_config.input_dtype or torch.float + weight_dtype = dtype_config.weight_dtype or torch.float + bias_dtype = dtype_config.bias_dtype or torch.float + output_dtype = dtype_config.output_dtype or torch.float + ( + qconfig_activation_dtype, + qconfig_weight_dtype, + qconfig_input_act_is_dynamic, + ) = get_qconfig_dtypes(qconfig) + qconfig_bias_dtype = ( + torch.float16 + if ( + qconfig_activation_dtype == torch.float16 + and qconfig_weight_dtype == torch.float16 + and not is_dynamic + ) + else torch.float + ) + + if is_dynamic: + is_match = ( + qconfig_input_act_is_dynamic + and input_dtype == qconfig_activation_dtype + and output_dtype == torch.float + and weight_dtype == qconfig_weight_dtype + ) + else: + is_match = ( + input_dtype == qconfig_activation_dtype + and output_dtype == qconfig_activation_dtype + and weight_dtype == qconfig_weight_dtype + and bias_dtype == qconfig_bias_dtype + ) + if is_match: + return True + return False + + +def _get_object_type_qconfig( + qconfig_mapping: QConfigMapping, + object_type: Callable | str, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) + + +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): + for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + + +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): + if module_name == "": + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_mapping.module_name_qconfigs: + return qconfig_mapping.module_name_qconfigs[module_name] + else: + parent, _ = _parent_name(module_name) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + + +def _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_name, global_qconfig +): + # get qconfig for module_name, + # fallback to module_name_regex_qconfig, module_type_qconfig, + # global_qconfig if necessary + module_type_qconfig = _get_object_type_qconfig( + qconfig_mapping, module_type, global_qconfig + ) + module_name_regex_qconfig = _get_module_name_regex_qconfig( + qconfig_mapping, module_name, module_type_qconfig + ) + module_name_qconfig = _get_module_name_qconfig( + qconfig_mapping, module_name, module_name_regex_qconfig + ) + return module_name_qconfig + + +def _get_flattened_qconfig_dict( + qconfig_mapping: QConfigMapping, +) -> dict[Callable | str, QConfigAny]: + """flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened: dict[Callable | str, QConfigAny] = {"": qconfig_mapping.global_qconfig} + flattened.update(qconfig_mapping.object_type_qconfigs) + flattened.update(qconfig_mapping.module_name_qconfigs) # type: ignore[arg-type] + return flattened + + +def _update_qconfig_for_qat( + qconfig_mapping: QConfigMapping, backend_config: BackendConfig +): + """ + Update the qconfig_mapping to account for module swaps during QAT. + During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. + """ + module_to_qat_module_class = get_module_to_qat_module(backend_config) + object_type_dict = qconfig_mapping.object_type_qconfigs + new_object_type_dict = object_type_dict.copy() + for k, v in new_object_type_dict.items(): + if k in module_to_qat_module_class: + object_type_dict[module_to_qat_module_class[k]] = v diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd8d7fe3a17439b46ad673a5aaf7eae28b7082f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py @@ -0,0 +1,226 @@ +# mypy: allow-untyped-defs +from abc import ABC +from collections.abc import Callable + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, +) +from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls +from torch.fx.graph import Node + +from .utils import all_node_args_have_no_tensors + + +__all__ = [ + "QuantizeHandler", + "BinaryOpQuantizeHandler", + "CatQuantizeHandler", + "ConvReluQuantizeHandler", + "LinearReLUQuantizeHandler", + "BatchNormQuantizeHandler", + "EmbeddingQuantizeHandler", + "RNNDynamicQuantizeHandler", + "DefaultNodeQuantizeHandler", + "FixedQParamsOpQuantizeHandler", + "CopyNodeQuantizeHandler", + "GeneralTensorShapeOpQuantizeHandler", + "CustomModuleQuantizeHandler", + "StandaloneModuleQuantizeHandler", +] + + +def _default_root_node_getter(node_pattern): + if node_pattern is None: + return node_pattern + while not isinstance(node_pattern, Node): + node_pattern = node_pattern[-1] + return node_pattern + + +# Base Pattern Handler +class QuantizeHandler(ABC): # noqa: B024 + """Base handler class for the quantizer patterns""" + + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Callable | None = None, + is_custom_module=False, + is_standalone_module=False, + ): + """Records pattern information in __init__, which will be used + in convert + """ + self.node_pattern = node_pattern + self.modules = modules + if root_node_getter is None: + root_node_getter = _default_root_node_getter + self.root_node = root_node_getter(node_pattern) + self.is_custom_module_ = is_custom_module + self.is_standalone_module_ = is_standalone_module + self.num_tensor_args = 0 + # determine how many of the first two args are Tensors (versus scalars) + # this distinguishes things like "x + y" from "x + 2" or "2 + x" + if isinstance(self.root_node, Node): + cache_for_no_tensor_check: dict[Node, bool] = {} + for arg_idx in range(len(self.root_node.args)): + arg = self.root_node.args[arg_idx] + if isinstance(arg, Node) and ( + not all_node_args_have_no_tensors( + arg, self.modules, cache_for_no_tensor_check + ) + ): + self.num_tensor_args += 1 + + def is_general_tensor_value_op(self) -> bool: + """ + Returns True if the operator works for both floating point and + quantized input, and does some computation based on the input Tensor, + or the ops that only re-arranges the Tensor values or query some metadata + about the Tensor + so we need to insert observer/fake_quant for the output of the + operator (same observer instance as input) + since the distribution of values is different for input and output + Tensors (for HistogramObserver) while they share the same quantization + parameters + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + return False + + def is_custom_module(self): + return self.is_custom_module_ + + def is_standalone_module(self): + return self.is_standalone_module_ + + +def _get_quantize_handler_cls( + observation_type: ObservationType, + dtype_configs: list[DTypeConfig], + num_tensor_args_to_observation_type: dict[int, ObservationType], +) -> type[QuantizeHandler]: + """ + Return a configurable QuantizeHandler that matches the given specifications from the backend. + """ + + class ConfigurableQuantizeHandler(QuantizeHandler): + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Callable | None = None, + ): + super().__init__(node_pattern, modules, root_node_getter) + if num_tensor_args_to_observation_type: + if self.num_tensor_args not in num_tensor_args_to_observation_type: + raise AssertionError( + f"Must provide observation_type config for tensor number {self.num_tensor_args}" + f" in num_tensor_args_to_observation_type for {node_pattern}" + ) + self.observation_type = num_tensor_args_to_observation_type[ + self.num_tensor_args + ] + else: + self.observation_type = observation_type + self.dtype_configs = dtype_configs + + def is_general_tensor_value_op(self) -> bool: + return ( + self.observation_type + == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + ) + + return ConfigurableQuantizeHandler + + +def _get_pattern_to_quantize_handlers( + backend_config: BackendConfig, +) -> dict[Pattern, QuantizerCls]: + """ + Note: Quantize handler is just a holder for some check methods like + (should_insert_observer_for_output), maybe this can be a enum as well, + we can refactor this after we convert the path for fbgemm/qnnpack fully to the + new path, this is not exposed to backend developers + """ + pattern_to_quantize_handlers = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + observation_type = config.observation_type + dtype_configs = config.dtype_configs + num_tensor_args_to_observation_type = ( + config._num_tensor_args_to_observation_type + ) + pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls( + observation_type, dtype_configs, num_tensor_args_to_observation_type + ) + return pattern_to_quantize_handlers + + +# TODO: remove this class, this is still exposed in torch.ao.quantization +# but we should be able to break bc +class BinaryOpQuantizeHandler(QuantizeHandler): + pass + + +class CatQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class ConvReluQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class LinearReLUQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class BatchNormQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class EmbeddingQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class RNNDynamicQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class DefaultNodeQuantizeHandler(QuantizeHandler): + """Common quantized op, first input and first output will be quantized""" + + +# TODO: remove this class +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class CopyNodeQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class CustomModuleQuantizeHandler(QuantizeHandler): + pass + + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class StandaloneModuleQuantizeHandler(QuantizeHandler): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1635936845a44ab895a3c6b0c5e07e9ec9951e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py @@ -0,0 +1,48 @@ +from collections.abc import Callable + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.fx._symbolic_trace import Tracer +from torch.fx.proxy import Scope + + +__all__ = [ + "QuantizationTracer", +] + + +class ScopeContextManager(torch.fx.proxy.ScopeContextManager): + def __init__( + self, scope: Scope, current_module: torch.nn.Module, current_module_path: str + ): + super().__init__(scope, Scope(current_module_path, type(current_module))) + + +class QuantizationTracer(Tracer): + def __init__( + self, skipped_module_names: list[str], skipped_module_classes: list[Callable] + ): + super().__init__() + self.skipped_module_names = skipped_module_names + self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.record_stack_traces = True + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + return ( + ( + ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) + and not isinstance(m, torch.nn.Sequential) + ) + or module_qualified_name in self.skipped_module_names + or type(m) in self.skipped_module_classes + or isinstance(m, _FusedModule) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a46d2057c5480ae036bbb847cf3d9bb185b29ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py @@ -0,0 +1,997 @@ +# mypy: allow-untyped-defs +import copy +import functools +import operator +import warnings +from collections import namedtuple +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +from torch.ao.quantization import QConfigAny, QuantType +from torch.ao.quantization.backend_config import DTypeWithConstraints +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + _is_activation_post_process, + FixedQParamsObserver, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + float16_dynamic_qconfig, + float16_static_qconfig, + qconfig_equals, +) +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _assert_and_get_unique_device, + activation_is_statically_quantized, +) +from torch.fx import GraphModule, map_arg +from torch.fx.graph import Graph, Node + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from .custom_config import PrepareCustomConfig + + +# TODO: revisit this list. Many helper methods shouldn't be public +__all__ = [ + "all_node_args_except_first", + "all_node_args_have_no_tensors", + "assert_and_get_unique_device", + "collect_producer_nodes", + "create_getattr_from_value", + "create_node_from_old_node_preserve_meta", + "EMPTY_ARG_DICT", + "get_custom_module_class_keys", + "get_linear_prepack_op_for_dtype", + "get_new_attr_name_with_prefix", + "get_non_observable_arg_indexes_and_types", + "get_qconv_prepack_op", + "get_skipped_module_name_and_classes", + "graph_module_from_producer_nodes", + "maybe_get_next_module", + "NodeInfo", + "node_arg_is_bias", + "node_arg_is_weight", + "NON_OBSERVABLE_ARG_DICT", + "NON_QUANTIZABLE_WEIGHT_OPS", + "return_arg_list", + "ObservedGraphModuleAttrs", +] + +NON_QUANTIZABLE_WEIGHT_OPS = { + torch.nn.functional.layer_norm, + torch.nn.functional.group_norm, + torch.nn.functional.instance_norm, +} + + +@dataclass +class ObservedGraphModuleAttrs: + node_name_to_qconfig: dict[str, QConfigAny] + node_name_to_scope: dict[str, tuple[str, type]] + prepare_custom_config: PrepareCustomConfig + equalization_node_name_to_qconfig: dict[str, Any] + qconfig_mapping: QConfigMapping + is_qat: bool + observed_node_names: set[str] + is_observed_standalone_module: bool = False + standalone_module_input_quantized_idxs: list[int] | None = None + standalone_module_output_quantized_idxs: list[int] | None = None + + +def node_arg_is_weight(node: Node, arg: Any) -> bool: + """Returns if node arg is weight""" + weight_index = None + if "target_dtype_info" in node.meta: + weight_index = node.meta["target_dtype_info"].get("weight_index", None) + if ( + weight_index is not None + and weight_index < len(node.args) + and node.args[weight_index] is arg + ): + return True + return node.kwargs.get("weight") is arg + + +def node_arg_is_bias(node: Node, arg: Any) -> bool: + """Returns if node arg is bias""" + bias_index = None + if "target_dtype_info" in node.meta: + bias_index = node.meta["target_dtype_info"].get("bias_index", None) + if ( + bias_index is not None + and bias_index < len(node.args) + and node.args[bias_index] is arg + ): + return True + return node.kwargs.get("bias") is arg + + +def get_custom_module_class_keys( + custom_module_mapping: dict[QuantType, dict[type, type]], +) -> list[Any]: + r"""Get all the unique custom module keys in the custom config dict + e.g. + Input: + { + QuantType.STATIC: { + CustomModule1: ObservedCustomModule + }, + QuantType.DYNAMIC: { + CustomModule2: DynamicObservedCustomModule + }, + QuantType.WEIGHT_ONLY: { + CustomModule3: WeightOnlyObservedCustomModule + }, + } + + Output: + # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes: set[Any] = set() + for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + + +def get_linear_prepack_op_for_dtype(dtype): + if dtype == torch.float16: + return torch.ops.quantized.linear_prepack_fp16 + elif dtype == torch.qint8: + return torch.ops.quantized.linear_prepack + else: + raise Exception("can't get linear prepack op for dtype:", dtype) # noqa: TRY002 + + +def get_qconv_prepack_op(conv_op: Callable) -> Callable: + prepack_ops = { + torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack, + torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack, + torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack, + torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack, + torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack, + torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack, + } + prepack_op = prepack_ops.get(conv_op) + if prepack_op is None: + raise AssertionError(f"Didn't find prepack op for {conv_op}") + return prepack_op + + +# Returns a function that can get a new attribute name for module with given +# prefix, for example, +# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') +# >> new_name = get_new_observer_name(module) +# new_name will be an unused attribute name on module, e.g. `_observer_1` +def get_new_attr_name_with_prefix(prefix: str) -> Callable: + prefix = prefix.replace(".", "_") + + def get_new_attr_name(module: torch.nn.Module): + def get_attr_name(i: int): + return prefix + str(i) + + i = 0 + attr_name = get_attr_name(i) + while hasattr(module, attr_name): + i += 1 + attr_name = get_attr_name(i) + return attr_name + + return get_new_attr_name + + +def collect_producer_nodes(node: Node) -> list[Node] | None: + r"""Starting from a target node, trace back until we hit input or + getattr node. This is used to extract the chain of operators + starting from getattr to the target node, for example:: + + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + + collect_producer_nodes(observed) will either return a list of nodes that + produces the observed node or None if we can't extract a self contained + graph without free variables(inputs of the forward function). + """ + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == "placeholder": + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == "call_function" and arg.target is getattr): + frontier.append(arg) + return nodes + + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: list[Node] +) -> GraphModule: + r"""Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + """ + if len(producer_nodes) == 0: + raise AssertionError("list of producer nodes can not be empty") + # since we traced back from node to getattr + producer_nodes.reverse() + graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node]) + + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + + +# TODO: delete +@functools.cache +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + return _assert_and_get_unique_device(module) + + +def create_getattr_from_value( + module: torch.nn.Module, + graph: Graph, + prefix: str, + value: Any, + device: torch.device | None = None, +) -> Node: + """ + Given a value of any type, creates a getattr node corresponding to the value and + registers the value as a buffer to the module. + """ + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + attr_name = get_new_attr_name(module) + if device is None: + device = assert_and_get_unique_device(module) + new_value = ( + value.detach().clone() + if isinstance(value, torch.Tensor) + else torch.tensor(value, device=device) + ) + module.register_buffer(attr_name, new_value) + # Create get_attr with value + attr_node = graph.create_node("get_attr", attr_name) + return attr_node + + +def all_node_args_have_no_tensors( + node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool] +) -> bool: + """ + If we know for sure that all of this node's args have no + tensors (are primitives), return True. If we either + find a tensor or are not sure, return False. Note: this + function is not exact. + """ + if cache and node in cache: + return cache[node] + + result = False # will be overwritten + if not isinstance(node, Node): + result = True + elif node.op == "placeholder": + result = False + elif node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError("node.target must be a string for call_module nodes") + if _is_activation_post_process(modules[node.target]): + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "call_module": + result = False + elif node.op == "call_function" and node.target is operator.getitem: + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "get_attr": + result = False + elif node.target is getattr and node.args[1] in ["ndim", "shape"]: + # x1 = x0.ndim + result = True + elif node.op == "call_method" and node.target == "size": + # x1 = x0.size(0) + result = True + else: + found_one_tensor = False + for arg in node.args: + if isinstance(arg, list): + for list_el in arg: + if isinstance(list_el, Node): + this_list_el_args_have_no_tensors = ( + all_node_args_have_no_tensors(list_el, modules, cache) + ) + found_one_tensor = found_one_tensor or ( + not this_list_el_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + elif isinstance(arg, int): + pass + else: + if isinstance(arg, Node): + this_arg_args_have_no_tensors = all_node_args_have_no_tensors( + arg, modules, cache + ) + found_one_tensor = found_one_tensor or ( + not this_arg_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + else: + found_one_tensor = True + result = not found_one_tensor + if cache: + cache[node] = result + return result + + +def all_node_args_except_first(node: Node) -> list[int]: + """ + Returns all node arg indices after first + """ + return list(range(1, len(node.args))) + + +def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]: + """ + Constructs a function that takes a node as arg and returns the arg_indices + that are valid for node.args + """ + + def arg_indices_func(node: Node) -> list[int]: + return [i for i in arg_indices if i < len(node.args)] + + return arg_indices_func + + +NodeInfo = namedtuple("NodeInfo", "op target") + +# this dict identifies which indices of a node are non tensors +# so that they can be propagated correctly since inserting observers +# for them would cause errors + +NON_OBSERVABLE_ARG_DICT: dict[ + NodeInfo, dict[type | torch.dtype, Callable[[Node], list[int]]] +] = { + NodeInfo("call_method", "masked_fill"): { + torch.bool: return_arg_list([1]), + float: return_arg_list([2]), + }, + NodeInfo("call_method", "permute"): {int: all_node_args_except_first}, + NodeInfo("call_method", "repeat"): {int: all_node_args_except_first}, + NodeInfo("call_method", "reshape"): {int: all_node_args_except_first}, + NodeInfo("call_method", "size"): {int: return_arg_list([1])}, + NodeInfo("call_method", "transpose"): {int: all_node_args_except_first}, + NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first}, + NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])}, + NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])}, + NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])}, + NodeInfo("call_method", "view"): {int: all_node_args_except_first}, +} + +EMPTY_ARG_DICT: dict[type | torch.dtype, Callable[[Node], list[int]]] = {} + + +def get_non_observable_arg_indexes_and_types( + node: Node, +) -> dict[type | torch.dtype, Callable[[Node], list[int]]]: + """ + Returns a dict with of non float tensor types as keys and values which correspond to a + function to retrieve the list (which takes the node as an argument) + """ + info = NodeInfo(node.op, node.target) + + return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT) + + +def maybe_get_next_module( + node: Node, + modules: dict[str, nn.Module], + target_module_type: type[nn.Module] | None = None, + target_functional_type: Any = None, +) -> Node | None: + """Gets the next module that matches what is needed in + is_target_module_type if it exists + + Args: + node: The node whose users we want to look at + target_module_type: Module type that we want to check + target_functional_type: Functional type that we want to check + """ + + for user in node.users: + if ( + user.op == "call_module" + and target_module_type is not None + and isinstance(modules[str(user.target)], target_module_type) + ): + return user + elif ( + user.op == "call_function" + and target_functional_type is not None + and user.target == target_functional_type + ): + return user + + return None + + +def create_node_from_old_node_preserve_meta( + quantized_graph: Graph, + create_node_args: tuple[Any, ...], + old_node: Node, +) -> Node: + """ + Creates `new_node` and copies the necessary metadata to it from `old_node`. + """ + new_node = quantized_graph.create_node(*create_node_args) + new_node.stack_trace = old_node.stack_trace + return new_node + + +def get_skipped_module_name_and_classes( + prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool +) -> tuple[list[str], list[type[Any]]]: + skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) + skipped_module_classes = copy.copy( + prepare_custom_config.non_traceable_module_classes + ) + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + skipped_module_names += list( + prepare_custom_config.standalone_module_names.keys() + ) + skipped_module_classes += list( + prepare_custom_config.standalone_module_classes.keys() + ) + skipped_module_classes += get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + + return skipped_module_names, skipped_module_classes + + +def _is_custom_module_lstm( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Any | None = None, +) -> bool: + """ + Return whether this refers to the custom module LSTM flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + if not isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") + return ( + isinstance(mod, torch.nn.LSTM) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.LSTM) + + +def _is_custom_module_mha( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Any | None = None, +) -> bool: + """ + Return whether this refers to the custom module MultiheadAttention flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + if not isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") + return ( + isinstance(mod, torch.nn.MultiheadAttention) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention) + + +def _get_module( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> torch.nn.Module | None: + """ + If `node` refers to a call_module node, return the module, else None. + """ + if node.op == "call_module" and str(node.target) in named_modules: + return named_modules[str(node.target)] + else: + return None + + +def _insert_dequant_stub( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attach a `DeQuantStub` to the model and create a node that calls this + `DeQuantStub` on the output of `node`, similar to how observers are inserted. + """ + prefix = "dequant_stub_" + get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix) + dequant_stub_name = get_new_dequant_stub_name(model) + dequant_stub = DeQuantStub() + setattr(model, dequant_stub_name, dequant_stub) + named_modules[dequant_stub_name] = dequant_stub + with graph.inserting_after(node): + return graph.call_module(dequant_stub_name, (node,)) + + +def _insert_dequant_stubs_for_custom_module_lstm_output( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Insert DeQuantStubs after each internal output node of custom module LSTM. + + Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)), + Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its + components through `getitem`. This function transforms the graph as follows: + + (1) Split the LSTM node into (output, (hidden0, hidden1)) + (2) Insert a DeQuantStub after each internal node + (3) Recombine the DeQuantStubs into the same structure as before + (4) Reroute all consumers of the original LSTM node and its sub-nodes + (e.g. lstm[0]) + + Before: + lstm_output + | + v + original_user(s) + After: + lstm_output + / \\ + / (getitem) \\ + / \\ + v v + output hidden + | / \\ + (DeQuantStub) (getitem) + | / \\ + v v v + output_dq hidden0 hidden1 + | | | + | (DeQuantStub) (DeQuantStub) + | | | + | v v + | hidden0_dq hidden1_dq + | \\ / + | (tuple) + | \\ / + | v v + | hidden_dq + \\ / + \\ (tuple) / + v v + lstm_output_dq + | + v + original_user(s) + + For step (4), reroute all users of the original LSTM node(s) as follows: + lstm_output -> lstm_output_dq + lstm_output[0] -> output_dq + lstm_output[1] -> hidden_dq + lstm_output[1][0] -> hidden0_dq + lstm_output[1][1] -> hidden1_dq + + Return the node `lstm_output_dq`. + """ + # (1) Split the LSTM node into (output, (hidden0, hidden1)) + # (2) Insert a DeQuantStub after each internal node + with graph.inserting_after(node): + output = graph.call_function(operator.getitem, (node, 0)) + output_dq = _insert_dequant_stub(output, model, named_modules, graph) + with graph.inserting_after(output_dq): + hidden = graph.call_function(operator.getitem, (node, 1)) + with graph.inserting_after(hidden): + hidden0 = graph.call_function(operator.getitem, (hidden, 0)) + hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph) + with graph.inserting_after(hidden0_dq): + hidden1 = graph.call_function(operator.getitem, (hidden, 1)) + hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph) + + # (3) Recombine the DeQuantStubs into the same structure as before + with graph.inserting_after(hidden1_dq): + hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],)) + with graph.inserting_after(hidden_dq): + lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],)) + + # (4) Reroute all consumers of the original LSTM node and its sub-nodes + for user in list(node.users.keys()): + if user != output and user != hidden: + user.replace_input_with(node, lstm_output_dq) + # The getitem and tuple nodes we added here may interfere with reference quantized + # pattern matching, so we need to redirect the consumers of internal nodes to the + # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached, + # in order to preserve reference patterns like "dequantize - consumer - quantize". + _reroute_tuple_getitem_pattern(graph) + return lstm_output_dq + + +def _maybe_get_custom_module_lstm_from_node_arg( + arg: Node, + named_modules: dict[str, torch.nn.Module], +) -> Node | None: + """ + Given an argument of a node, if the argument refers to the path through which the node + is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise. + + This is used to determine whether a node is a consumer of custom module LSTM, and, if so, + skip inserting input observers for this node. This is because custom module LSTM produces + quantized outputs, so inserting an input observer for the consumer of custom module LSTM + would unnecessarily quantize the outputs again. + + lstm -> consumer + + In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with + DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + This tuple can be consumed in one of four ways: + + lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0] + lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1] + lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1] + lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm + + Thus, we must match against the above patterns instead of simply checking the parent node + to determine whether this node is a consumer of a custom module LSTM. + """ + + def match_dq(a): + return isinstance(_get_module(a, named_modules), DeQuantStub) + + def match_lstm(a): + return _is_custom_module_lstm(a, named_modules) + + def match_getitem(a): + return a.op == "call_function" and a.target is operator.getitem + + def match_tuple(a): + return a.op == "call_function" and a.target is tuple + + def _match_pattern(match_pattern: list[Callable]) -> Node | None: + """ + Traverse up the graph and match the args one by one. + If there is a match, return the last matched node, or None otherwise. + """ + a = arg + for i, match in enumerate(match_pattern): + if not match(a): + return None + # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],) + if i < len(match_pattern) - 1: + if match is match_tuple: + a = a.args[0][0] # type: ignore[assignment,index] + else: + a = a.args[0] # type: ignore[assignment] + # pyrefly: ignore [bad-return] + return a + + all_match_patterns = [ + [match_dq, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_getitem, match_lstm], + [match_dq, match_getitem, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_lstm], + ] + + for p in all_match_patterns: + matched_node = _match_pattern(p) + if matched_node is not None: + return matched_node + return None + + +def _reroute_tuple_getitem_pattern(graph: Graph): + """ + Search for patterns where N consecutive `tuple` call_function nodes are followed by + N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes. + If we find this pattern, reroute the consumers of the last `getitem` to skip these + N `tuple` and `getitem` nodes. + + Before: + + a b c + | \\ / + \\ tuple + \\ / + tuple + | + getitem(1) + | + getitem(0) + | + d + + After: + + b + | + d + """ + + def find_patterns( + node: Node, + index_stack: list[int], + current_pattern: list[Node], + matched_patterns: list[list[Node]], + seen: set[tuple[Node, tuple[int, ...]]], + ): + """ + Traverse the graph recursively to match for the N-tuple - N-getitem patterns, + starting at the given node. + + We use a stack to keep track of the expected `getitem` indices, since these are + reversed from the `tuple` indices. In the above example, the stack after + (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first + and then by getitem(0). + + TODO: traverse upwards from the output and handle the case when tuple is not a + separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c))) + """ + if len(index_stack) == 0 and len(current_pattern) > 0: + matched_patterns.append(copy.copy(current_pattern)) + current_pattern.clear() + + # Avoid duplicating work + state = (node, tuple(index_stack)) + if state in seen: + return + seen.add(state) + + # Iterate through users of this node to find tuple/getitem nodes to match + for user in node.users: + if user.op == "call_function" and user.target is tuple: + for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type] + if user_arg == node: + index_stack.append(i) + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + elif user.op == "call_function" and user.target is operator.getitem: + if len(index_stack) > 0: + if user.args[1] == index_stack[-1]: + index_stack.pop() + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + return matched_patterns + + # Collect all matched patterns + matched_patterns: list[list[Node]] = [] + seen: set[tuple[Node, tuple[int, ...]]] = set() # (node, index_stack) + for node in graph.nodes: + find_patterns(node, [], [], matched_patterns, seen) + + # For each pattern, redirect all consumers of the last getitem node to the correct input + # of the first tuple node + for pattern in matched_patterns: + first_tuple = pattern[0] + last_getitem = pattern[-1] + if not (first_tuple.op == "call_function" and first_tuple.target is tuple): + raise AssertionError( + "first tuple node must be a call_function with target tuple" + ) + if not ( + last_getitem.op == "call_function" + and last_getitem.target is operator.getitem + ): + raise AssertionError( + "last getitem node must be a call_function with target operator.getitem" + ) + last_getitem_index = last_getitem.args[1] + new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index] + for user in list(last_getitem.users.keys()): + user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type] + + +def _get_observer_from_activation_post_process( + activation_post_process: ObserverBase | FakeQuantizeBase, +) -> ObserverBase: + """ + If `activation_post_process` is an observer, return the observer. + If `activation_post_process` is a fake quantize, return the internal observer. + """ + if isinstance(activation_post_process, ObserverBase): + return activation_post_process + else: + if not isinstance(activation_post_process, FakeQuantizeBase): + raise AssertionError( + "activation_post_process must be an ObserverBase or FakeQuantizeBase" + ) + return activation_post_process.activation_post_process # type: ignore[return-value] + + +def _qconfig_satisfies_dtype_config_constraints( + qconfig: QConfigAny, + dtype_with_constraints: DTypeWithConstraints, + is_activation: bool = True, +) -> bool: + """ + Return whether `qconfig` satisfies the following constraints from the backend, + specified through the activation and weight DTypeWithConstraints. + + 1. QConfig specified a quantization range that falls within the backend's, if any + 2. QConfig specified a min scale value that is >= the backend's, if any + 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has + scale and zero point that match the backend's, if any + + If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`. + If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True. + """ + + # TODO: log warnings only when the user enabled a debug flag + def _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process: ObserverBase | FakeQuantizeBase, + dtype_with_constraints: DTypeWithConstraints, + debug_string: str, + ) -> bool: + observer = _get_observer_from_activation_post_process(activation_post_process) + app_quant_min = getattr(observer, "quant_min", None) + app_quant_max = getattr(observer, "quant_max", None) + # TODO: for now, just use the existing eps value as scale_min. In the future, we should + # resolve the differences between the two, either by renaming eps or some other way + app_scale_min = getattr(observer, "eps", None) + backend_quant_min = dtype_with_constraints.quant_min_lower_bound + backend_quant_max = dtype_with_constraints.quant_max_upper_bound + backend_scale_min = dtype_with_constraints.scale_min_lower_bound + backend_scale_exact_match = dtype_with_constraints.scale_exact_match + backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match + # check quantization ranges + if backend_quant_min is not None and backend_quant_max is not None: + if app_quant_min is None or app_quant_max is None: + warnings.warn( + f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}", + stacklevel=2, + ) + return False + elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max: + warnings.warn( + f"QConfig {debug_string} quantization range must fall within the backend's:\n" + f"QConfig range = ({app_quant_min}, {app_quant_max}), " + f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), " + f"ignoring {qconfig}", + stacklevel=2, + ) + return False + # check scale min + if backend_scale_min is not None: + if app_scale_min is None: + warnings.warn( + f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}", + stacklevel=2, + ) + return False + if app_scale_min < backend_scale_min: + warnings.warn( + f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to " + f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}", + stacklevel=2, + ) + return False + # check fixed scale and zero point + if ( + backend_scale_exact_match is not None + and backend_zero_point_exact_match is not None + ): + # For tests only, accept the following qconfigs for now + # TODO: handle fp16 qconfigs properly + for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]: + if qconfig_equals(qconfig, accepted_qconfig): + return True + suggestion_str = ( + "Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n" + ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n' + " model = prepare_fx(model, qconfig_mapping, example_inputs)" + ) + if not isinstance( + activation_post_process, FixedQParamsObserver + ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize): + warnings.warn( + f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize " + f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}", + stacklevel=2, + ) + return False + if ( + observer.scale != backend_scale_exact_match + or observer.zero_point != backend_zero_point_exact_match + ): + warnings.warn( + f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) " + f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), " + f"ignoring {qconfig}.\n{suggestion_str}", + stacklevel=2, + ) + return False + return True + + if qconfig is None or dtype_with_constraints.dtype is None: + return True + + activation_post_process_ctr = ( + qconfig.activation if is_activation else qconfig.weight + ) + debug_string = "activation" if is_activation else "weight" + satisfies_constraints = True + if activation_post_process_ctr is not None: + activation_post_process = activation_post_process_ctr() + if not _is_activation_post_process(activation_post_process): + raise AssertionError( + "activation_post_process must be an activation post process" + ) + # If dtypes don't match, don't check the activation_post_process and return True early + if activation_post_process.dtype != dtype_with_constraints.dtype: + return True + satisfies_constraints = ( + _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process, dtype_with_constraints, debug_string + ) + ) + return satisfies_constraints diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b186ab8389cbd2c231719c89dc2d7940194777 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dcfefc3ef27f5d0adf356de6c849a78ad63b697 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a7cc45aaa60d26a3886944e03c7b3490500b37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d1be5371e0dd381178773dfa223b9a910d107a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44efa3110609de88bd5fa423dba6a1da5e024962 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e96ce5697a71356ac4c3b728c40b310933382c72 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cef3711819b9249a2b02e44d41ca2875eac8434 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cecb077fa9c80177b4876abaf366274d2b1f0886 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8018e2be69618ac7fe2773322edce3687fd2a54 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d1de3e0ee841ab669b8e6baf87da6b7e8e96a13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..270e8338295802a82d1dace31868afb29e855227 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75f32eb8d801f271b51d12671bc2e4cf7e4eb5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py @@ -0,0 +1,891 @@ +# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +import logging +from abc import ABCMeta +from typing import Any + +import torch +from torch.ao.quantization.observer import ( + AffineQuantizedObserverBase, + get_block_size, + Granularity, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +logger = logging.getLogger(__name__) + +FP8_TYPES = { + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} + +""" +Map from dtype to the bound value of integers +TODO: maybe can replace this with call to torch.iinfo +""" +_DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype | TorchAODType, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), +} +_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + + +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + + +# TODO: decide on if we want to allow custom quant_min/quant_max here +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype in FP8_TYPES: + quant_min_lower_bound, quant_max_upper_bound = ( + torch.finfo(dtype).min, + torch.finfo(dtype).max, + ) + elif dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + else: + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + if quant_min < quant_min_lower_bound: + raise AssertionError( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + if quant_max > quant_max_upper_bound: + raise AssertionError( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + return quant_min, quant_max + + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + if len(block_size) != len(input_size): + raise AssertionError( + "block_size length must equal input_size length, got " + f"block_size={block_size}, input_size={input_size}" + ) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + if input_size[i] % block_size[i] != 0: + raise AssertionError( + f"Expecting input size at {i} dimension: {input_size[i]} to be divisible " + f"by block_size at {i} dimension: {block_size[i]}" + ) + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def _register_custom_op(lib): + """This decorator is used to preserve some high level operators for torch.export.export + while still allow them to be decomposed for inductor path + + requirement: make sure `fn.__name__[1:]` is the operator name you want to register + + NOTE: This should be applied at the top, after all other decorators have been applied + NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, + e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make + sense for downstream system (like executorch) to accept as well + + Example: + lib = torch.library.Library("my_namespace', "FRAGMENT") + + register_custom_op = _register_custom_op(lib) + + @register_custom_op + def _the_op_that_needs_to_be_preserved(...) + ... + + # after this, `_the_op_that_needs_to_be_preserved` will be preserved as + # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after + # torch.export.export / torch._export.export_for_training + + """ + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + from torch._library.infer_schema import infer_schema + + # expecting fn.__name__ starts with `_` and we want to take the rest + # to be the name of the custom op + if fn.__name__[0] != "_": + raise AssertionError( + f"Expecting function name starts with `_`, got {fn.__name__}" + ) + if any(c in fn.__name__ for c in ".<>"): + raise AssertionError( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, "CompositeImplicitAutograd") + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + register_decomposition([op])(fn) + return op + + return decorator + + +quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901 + +register_custom_op = _register_custom_op(quant_lib) + + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: tuple[int, ...], + target_dtype: torch.dtype, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, +) -> tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. + This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name if zero_point_domain is not None else None, + min_val, + max_val, + ) + + +@register_custom_op +def _choose_qparams_affine( + input: torch.Tensor | None, + mapping_type: str, + block_size: list[int], + target_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + eps: float | None = None, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: str | None = "INT", + min_val: torch.Tensor | None = None, + max_val: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """op definition that has compatible signatures with custom op library + + The op does the following: + 1. figure out the dimension for reduction based on block_size + 2. find min_val/max_val based on the dimension for reduction + 3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` + and `zero_point_domain` + """ + quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + if mapping_type not in [ + MappingType.SYMMETRIC.name, + MappingType.SYMMETRIC_NO_CLIPPING_ERR.name, + MappingType.ASYMMETRIC.name, + ]: + raise AssertionError(f"Unsupported mapping type: {mapping_type}") + if target_dtype in FP8_TYPES: + if mapping_type != MappingType.SYMMETRIC.name: + raise AssertionError( + f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + ) + + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps + + if len(block_size) != input.dim(): + raise AssertionError( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + if min_val is None or max_val is None: + raise AssertionError( + f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + ) + if min_val.dtype != max_val.dtype: + raise AssertionError( + f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + ) + + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps + + if preserve_zero: + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + else: + min_val_neg = min_val + max_val_pos = max_val + + if ( + mapping_type == MappingType.SYMMETRIC.name + or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + ): + # scales + if mapping_type == MappingType.SYMMETRIC.name: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + else: + if mapping_type != MappingType.SYMMETRIC_NO_CLIPPING_ERR.name: + raise AssertionError( + f"Expected mapping_type to be SYMMETRIC_NO_CLIPPING_ERR, got {mapping_type}" + ) + # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and + # quant_max = 7. + # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding + # error than the existing SYMMETRIC case. + # - If smax is bigger: it covers the positive values up to 7. The round + # error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after + # quantization. + smin = min_val_neg / float(quant_min) + smax = max_val_pos / float(quant_max) + mask = smin > smax + scale = torch.where(mask, smin, smax) + # zeros + if not preserve_zero: + raise ValueError( + "preserve_zero == False is not supported for symmetric quantization" + ) + if ( + zero_point_domain is not None + and zero_point_domain != ZeroPointDomain.INT.name + ): + raise ValueError( + "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" + ) + scale = torch.clamp(scale, min=eps) + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) + else: + if mapping_type != MappingType.ASYMMETRIC.name: + raise AssertionError( + f"Expected mapping_type to be ASYMMETRIC, got {mapping_type}" + ) + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + if preserve_zero: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError( + "if not preserve_zero, zero_point must be in FLOAT domain" + ) + mid_point = (quant_max + quant_min + 1) / 2 + zero_point = min_val_neg + scale * mid_point + + if zero_point is not None: + zero_point = zero_point.to(dtype=zero_point_dtype) + return scale.to(dtype=scale_dtype), zero_point + + +@torch.no_grad() +def quantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + output_dtype: torch.dtype, + quant_min: int | float | None = None, + quant_max: int | float | None = None, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): original float32, float16 or bfloat16 Tensor + block_size: (Tuple[int, ...]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Note: + How can block_size represent different granularities? + let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different + granularities: + + granularity type | block_size + per_tensor | (3, 3, 10, 10) + per_axis (axis=0) | (1, 3, 10, 10) + per_axis (axis=1) | (3, 1, 10, 10) + per_group (groupsize=2) | (3, 3, 10, 2) + per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + + Output: + quantized tensor with requested dtype + """ + return _quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + ) + + +@register_custom_op +def _quantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + output_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + zero_point_domain: str | None = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library + + Note: + zero_point_domain is optional specifies how we quantize the floating point to quantized data: + INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization + Where we do not want to round values to nearest integer and instead scale and cast. + """ + quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) + # workaround for uintx dtypes, since we don't have native Uintx dtype connected with + # torch.uintx dtypes yet + if output_dtype in _SUB_BYTE_UINT_BOUNDS: + output_dtype = torch.uint8 + return _quantize_affine_no_dtype_cast( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + ).to(output_dtype) + + +def _quantize_affine_no_dtype_cast( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + quant_min: int | float, + quant_max: int | float, + zero_point_domain: str | None = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """ + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to original shape + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise AssertionError(f"Unsupported input dtype: {input.dtype}") + if len(block_size) != input.dim(): + raise AssertionError(f"Got input dim:{input.dim()}, block_size: {block_size}") + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + quant = torch.clamp( + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max + ) + elif zero_point_domain == ZeroPointDomain.NONE.name: + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is NONE" + ) + quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) + elif zero_point_domain is None: + # This case handles quantization for float8 we expect no zero point and no zero point domain + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is None" + ) + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError(f"Unexpected zero_point_domain: {zero_point_domain}") + mid_point = (quant_max + quant_min + 1) / 2 + min_val = zero_point - scale * mid_point + quant = torch.clamp( + torch.round((input - min_val) / scale), quant_min, quant_max + ) + quant = quant.view(original_shape) + + return quant + + +def dequantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + input_dtype: torch.dtype, + quant_min: int | float | None = None, + quant_max: int | float | None = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument + block_size: (List[int]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (Tensor): quantization parameter for affine quantization + zero_point (Tensor): quantization parameter for affine quantization + input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for input Tensor + quant_max (Optional[int]): maximum quantized value for input Tensor + output_dtype (torch.dtype): dtype for output Tensor, default is fp32 + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Output: + dequantized Tensor, with requested dtype or fp32 + """ + return _dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + output_dtype=output_dtype, + ) + + +@register_custom_op +def _dequantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + input_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + zero_point_domain: str | None = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library""" + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input_dtype not in _SUB_BYTE_UINT_BOUNDS: + if input.dtype != input_dtype: + raise AssertionError(f"Expected: {input_dtype}, got: {input.dtype}") + if output_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise AssertionError(f"Unsupported output dtype: {output_dtype}") + quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) + return _dequantize_affine_no_dtype_check( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + output_dtype, + ) + + +def _dequantize_affine_no_dtype_check( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + quant_min: int | float, + quant_max: int | float, + zero_point_domain: str | None = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function converts AQT tensors to their high precision floating point representation + + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to original shape and change dtype to the output_dtype + """ + if len(block_size) != input.dim(): + raise AssertionError(f"Got input dim:{input.dim()}, block_size: {block_size}") + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + # Force a copy to avoid input modification due + # to upcoming in-place operations. + dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant = dequant - zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain == ZeroPointDomain.NONE.name: + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is NONE" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is None" + ) + if not _is_float8_type(input.dtype): + raise AssertionError( + f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError(f"Unexpected zero point domain: {zero_point_domain}") + # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) + mid_point = (quant_max + quant_min + 1) / 2 + # This should allocate new memory and avoid input modification + dequant = input - mid_point + dequant = dequant.to(output_dtype) + dequant *= scale + if zero_point is not None: + dequant += zero_point + + return dequant.view(original_shape).to(output_dtype) + + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + if self.granularity is None: + raise AssertionError("granularity is None") + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + if self.min_val.shape != min_val.shape: + raise AssertionError( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + if self.max_val.shape != max_val.shape: + raise AssertionError( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, "min_val") and hasattr(self, "max_val")): + raise AssertionError( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + averaging_constant=0.01, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + is_dynamic=False, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + if self.granularity is None: + raise AssertionError("granularity is None") + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + if self.min_val.shape != min_val.shape: + raise AssertionError( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + if self.max_val.shape != max_val.shape: + raise AssertionError( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) + max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, "min_val") and hasattr(self, "max_val")): + raise AssertionError( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + is_dynamic=False, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input): + self.block_size = get_block_size(input.shape, self.granularity) + self.original_dtype = input.dtype + return input + + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..6eaeaa46a924893fa0aace363f3040d5e2d692de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -0,0 +1,341 @@ +import copy +import logging +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +import torch +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node +from torch.nn import functional as F + + +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" + +log = logging.getLogger(__name__) + + +def generate_numeric_debug_handle(ep: ExportedProgram) -> None: + """ + Attach numeric_debug_handle_id for all nodes in the graph module of the given + ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder. + Notice that nodes like getattr are out of scope since they are not in the graph. + + The graph nodes of input exported program are modified inplace. + + Here's an example of using debug handle quantize flow:: + + ep = export_for_training(eager_model, example_inputs) + generate_numeric_debug_handle(ep) + + m = ep.module() + quantizer = XNNPACKQuantizer() + m = prepare_pt2e(m, quantizer) + m = convert_pt2e(m) + """ + + # Sanity check the input data type + if not isinstance(ep, ExportedProgram): + raise ValueError( + f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}" + ) + + unique_id = 0 + + def _find_max_id(node: torch.fx.Node) -> None: + nonlocal unique_id + unique_id = max( + unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0) + ) + + def _assign_debug_handle(node: torch.fx.Node) -> None: + nonlocal unique_id + if CUSTOM_KEY not in node.meta: + node.meta[CUSTOM_KEY] = {} + + if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]: + node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id + unique_id += 1 + + # Find the max ID that exists in the graph first, in case part of the graph + # has already been annotated. This way we guarantee there are no duplicate + # handle IDs. + bfs_trace_with_node_process(ep, _find_max_id) + + unique_id += 1 + + # Assign debug handles to all nodes in the graph that don't have one based on the + # max ID found in the previous step. + bfs_trace_with_node_process(ep, _assign_debug_handle) + + +def _detach(x: object) -> object: + detached: object = None + if isinstance(x, torch.Tensor): + detached = x.detach() + elif isinstance(x, (list, tuple)): + detached = type(x)([_detach(e) for e in x]) + elif isinstance(x, dict): + detached = {k: _detach(e) for k, e in x.items()} + else: + detached = x + return detached + + +def _tensor_shape_equals(x: object, y: object) -> bool: + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return x.shape == y.shape + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y)) + elif isinstance(x, dict) and isinstance(y, dict): + all_equal = True + for k in x: + all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k])) + return all_equal + else: + log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y) + return type(x) is type(y) and x == y + + +def _loss_fn( + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object +) -> object: + """The returned loss will have the same structure as `x` and `y`, e.g. + if both are Tensor, we'll return a Tensor + if both are list, we'll return a list of Tensors + if both are dict, we'll return a dict with the same key, and value being the loss between the + two Tensors + """ + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return loss(x.to(torch.float32), y.to(torch.float32)) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)]) + elif isinstance(x, dict) and isinstance(y, dict): + return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()} + else: + return None + + +class OutputLogger(torch.nn.Module): + """ + Base class for capturing output values for nodes in a GraphModule, it only captures + Tensor output currently, but we can extend it to work for other types of inputs later if needed + """ + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + debug_handle: int, + node_name: str | None = None, + nn_module_stack: object | None = None, + ) -> None: + super().__init__() + self.node_name = node_name + self.nn_module_stack = nn_module_stack + self.debug_handle = debug_handle + self.stats: list[object] = [] + + def forward(self, x: object) -> object: + self.stats.append(_detach(x)) + return x + + def __extra_repr__(self) -> str: + return ( + f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" + ) + + +def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: + """For a given node, adds an OutputLogger that observes the output of that node, + and all its users use the OutputLogger output instead. + The OutputLogger will contain the debug_handle which can be used to compare + graphs after transforms""" + + # to avoid circular dep + from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix + + # add a logger after the node + with model.graph.inserting_after(node): + get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger") + logger_name = get_new_attr_name(model) + setattr( + model, + logger_name, + OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + ) + logger_node = model.graph.call_module(logger_name, (node,), {}) + + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is logger_node: + continue + user_node.replace_input_with(node, logger_node) + + return logger_node + + +def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: + """Add output loggers to node that has numeric_debug_handle + + Args: + model (GraphModule): original model + Returns: + a model with output loggers for all nodes that has numeric_debug_handle_id + """ + # don't change the original model + model = copy.deepcopy(model) + for n in model.graph.nodes: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): + continue + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] + _insert_logger(model, n, numeric_debug_handle) + + model.recompile() + return model + + +@dataclass(frozen=True) +class QuantizationComparisonResult: + actual: torch.Tensor + ref: torch.Tensor + + @property + def mse_loss(self) -> object: + return self.loss(F.mse_loss) + + @property + def sqnr(self) -> object: + return self.loss(compute_sqnr) + + def loss( + self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> object: + return _loss_fn(loss_function, self.actual, self.ref) + + def __repr__(self) -> str: + # Don't include the tensors themselves as they are quite large to print + # out. + return ( + f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})" + ) + + def __post_init__(self) -> None: + if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}" + ) + + if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}" + ) + + if not _tensor_shape_equals(self.ref, self.actual): + raise ValueError( + f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}" + ) + + +@dataclass(frozen=True) +class NodeAccuracySummary: + handle: int + actual_node_name: str + actual_module_stack: str + ref_node_name: str + ref_module_stack: str + results: Sequence[QuantizationComparisonResult] + + +def _module_stack_to_str(module_stack: object) -> str: + """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear") + to "mod.foo.0.linear" + """ + if not isinstance(module_stack, dict): + return str(module_stack) + module_values_list = list(module_stack.values()) + if len(module_values_list) > 0: + owning_module = module_values_list[-1][0] + return str(owning_module) + else: + return str(module_stack) + + +def extract_results_from_loggers( + model: GraphModule, +) -> dict[int, tuple[str | None, object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug handle. + The reason we have a list of object, instead of Tensor is because the output of node may not be + a Tensor, it could be (nested) list, tuple or dict as well. + + Returns: + A dict is keyed by the debug_handle id and the values are a list of object recorded + in loggers + + """ + # Results maps debug handle to a tensor list for each model being compared. + handles: dict[int, tuple[str | None, object, list[object]]] = {} + for _name, module in model.named_children(): + if isinstance(module, OutputLogger) and len(module.stats) > 0: + handles[module.debug_handle] = ( + module.node_name, + module.nn_module_stack, + module.stats, + ) + + return handles + + +def compare_results( + ref_results: dict[int, tuple[str | None, object, list[torch.Tensor]]], + actual_results: dict[int, tuple[str | None, object, list[torch.Tensor]]], +) -> dict[int, NodeAccuracySummary]: + """Given two dict mapping from `debug_handle_id` (int) to list of tensors + return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + comparison information like SQNR, MSE etc. + + Args: + ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id + actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + + Returns: + Dict[int, NodeAccuracySummary] + """ + comparisons = {} + for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_handle not in actual_results: + log.debug( + "Cannot compare for handle %s because it wasn't found in the transformed model", + debug_handle, + ) + continue + actual_name, actual_stack, actual_stats = actual_results[debug_handle] + try: + results = [ + QuantizationComparisonResult(actual=a, ref=b) + for a, b in zip(actual_stats, ref_stats) + ] + except Exception as e: + # Add extra information for an exception from QuantizationComparisonResult + # if the shapes didn't match, to include the handle and the node names. + raise ValueError( + f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + ) from e + + comparisons[debug_handle] = NodeAccuracySummary( + handle=debug_handle, + actual_node_name=actual_name or "", + actual_module_stack=_module_stack_to_str(actual_stack), + ref_node_name=ref_name or "", + ref_module_stack=_module_stack_to_str(ref_stack), + results=results, + ) + + return comparisons diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..81c03e51414320e3faee0f6b8906ea38910c41ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -0,0 +1,81 @@ +import logging +import operator + +import torch +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _is_valid_annotation, +) +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["DuplicateDQPass"] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _maybe_duplicate_dq( + gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node +) -> None: + annotation = user.meta.get("quantization_annotation", None) + if not _is_valid_annotation(annotation): # type: ignore[arg-type] + return + with gm.graph.inserting_after(dq_node): + new_node = gm.graph.node_copy(dq_node) + + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == dq_node: + return new_node + else: + return n + + new_args = map_arg(user.args, maybe_replace_node) + new_kwargs = map_arg(user.kwargs, maybe_replace_node) + user.args = new_args # type: ignore[assignment] + user.kwargs = new_kwargs # type: ignore[assignment] + + +class DuplicateDQPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: + dq_users = _filter_sym_size_users(node) + if len(dq_users) <= 1: + continue + # Do not duplicate dq for dynamic quantization + # Pattern: choose_qparam - getitem - q - dq + q_node = node.args[0] + if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: + getitem_node = q_node.args[1] + if ( + isinstance(getitem_node, torch.fx.node.Node) + and getitem_node.op == "call_function" + and getitem_node.target is operator.getitem + ): + choose_qparam_node = getitem_node.args[0] + if ( + isinstance(choose_qparam_node, torch.fx.node.Node) + and choose_qparam_node.op == "call_function" + and choose_qparam_node.target + == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + continue + for user in dq_users: + _maybe_duplicate_dq(graph_module, node, user) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70cca73dd00dcb4bd865dda4f2718a610496323e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import types + +import torch +import torch.nn.functional as F +from torch.ao.quantization.utils import _assert_and_get_unique_device + + +__all__ = [ + "model_is_exported", +] + +_EXPORTED_TRAINING_ATTR = "_exported_training" + + +class _WrapperModule(torch.nn.Module): + """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you + are trying to export a callable. + """ + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`.""" + return self.fn(*args, **kwargs) + + +def model_is_exported(m: torch.nn.Module) -> bool: + """ + Return True if the `torch.nn.Module` was exported, False otherwise + (e.g. if the model was FX symbolically traced or not traced at all). + """ + return isinstance(m, torch.fx.GraphModule) and any( + "val" in n.meta for n in m.graph.nodes + ) + + +def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch dropout patterns in the model between train and eval modes. + + Dropout has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the dropout behavior between the two modes, so here we need to rewrite the aten + dropout patterns manually to achieve the same effect. + + See https://github.com/pytorch/pytorch/issues/103681. + """ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + for inplace in [False, True]: + + def dropout_train(x): + return F.dropout(x, p=0.5, training=True, inplace=inplace) + + def dropout_eval(x): + return F.dropout(x, p=0.5, training=False, inplace=inplace) + + example_inputs = (torch.randn(1),) + if train_to_eval: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + else: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch batchnorm patterns in the model between train and eval modes. + + Batchnorm has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the batchnorm behavior between the two modes, so here we need to rewrite the aten + batchnorm patterns manually to achieve the same effect. + """ + # TODO(Leslie): This function still fails to support custom momentum and eps value. + # Enable this support in future updates. + + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + def bn_train( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + + def bn_eval( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False + ) + + example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + device = _assert_and_get_unique_device(m) + is_cuda = device is not None and device.type == "cuda" + bn_train_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_train), + example_inputs, + is_cuda, + ) + bn_eval_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_eval), + example_inputs, + is_cuda, + ) + + if train_to_eval: + match_pattern = bn_train_aten + replacement_pattern = bn_eval_aten + else: + match_pattern = bn_eval_aten + replacement_pattern = bn_train_aten + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +# TODO: expose these under this namespace? +def _move_exported_model_to_eval(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to eval mode. + + This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing inference on the model. + + This call is idempotent; if the model is already in eval mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True) + if not is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, False) + _replace_dropout(model, train_to_eval=True) + _replace_batchnorm(model, train_to_eval=True) + return model + + +def _move_exported_model_to_train(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to train mode. + + This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing training on the model. + + This call is idempotent; if the model is already in train mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False) + if is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, True) + _replace_dropout(model, train_to_eval=False) + _replace_batchnorm(model, train_to_eval=False) + return model + + +def _allow_exported_model_train_eval(model: torch.fx.GraphModule): + """ + Allow users to call `model.train()` and `model.eval()` on an exported model, + but with the effect of changing behavior between the two modes limited to special + ops only, which are currently dropout and batchnorm. + + Note: This does not achieve the same effect as what `model.train()` and `model.eval()` + does in eager models, but only provides an approximation. In particular, user code + branching on `training` flag will not function correctly in general because the branch + is already specialized at export time. Additionally, other ops beyond dropout and batchnorm + that have different train/eval behavior will also not be converted properly. + """ + + def _train(self, mode: bool = True): + if mode: + _move_exported_model_to_train(self) + else: + _move_exported_model_to_eval(self) + + def _eval(self): + _move_exported_model_to_eval(self) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b46011d8c41d3dc0d980e33caecb22b615fd22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections import OrderedDict +from collections.abc import Callable, Sequence +from typing import Any + +import torch +from torch.export import ExportedProgram +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + check_subgraphs_connected, + get_source_partitions, + SourcePartition, +) + + +__all__ = [ + "find_sequential_partitions", + "get_equivalent_types", + "update_equivalent_types_dict", + "bfs_trace_with_node_process", +] + +_EQUIVALENT_TYPES: list[set] = [ + {torch.nn.Conv1d, torch.nn.functional.conv1d}, + {torch.nn.Conv2d, torch.nn.functional.conv2d}, + {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d}, + {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_}, + {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm}, + {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_}, + {torch.add, operator.add, operator.iadd, "add", "add_"}, + {torch.mul, operator.mul, operator.imul, "mul", "mul_"}, +] + + +def _create_equivalent_types_dict(): + _DICT = {} + for values in _EQUIVALENT_TYPES: + for v in values: + _DICT[v] = list(values) + return _DICT + + +_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def get_equivalent_types() -> list[set]: + return _EQUIVALENT_TYPES + + +def update_equivalent_types_dict(customized_equivalent_types=None): + """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + When customized_equivalent_types passes in, + re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + """ + if customized_equivalent_types is None: + raise ValueError("customized_equivalent_types should not be None") + global _EQUIVALENT_TYPES + global _EQUIVALENT_TYPES_DICT + _EQUIVALENT_TYPES = customized_equivalent_types + _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def _partitions_sequential(partitions: Sequence[SourcePartition]): + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def _get_matching_types(partition_type): + matching_types = [partition_type] + if partition_type in _EQUIVALENT_TYPES_DICT: + matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type]) + return matching_types + + +def _valid_type_sequence(partition_types: list[Any]): + partition_types_set = set() # type: ignore[var-annotated] + for partition_type in partition_types: + matching_types = _get_matching_types(partition_type) + matching_types_set = set(matching_types) + if len(partition_types_set & matching_types_set) > 0: + return False + partition_types_set |= matching_types_set + return True + + +def find_sequential_partitions( + gm: torch.fx.GraphModule, + partition_types: list[Any], + include_functional_equivalent=True, + filter_fn: Callable[[Node], bool] | None = None, +): + if not _valid_type_sequence(partition_types): + raise ValueError( + f"Invalid partition types: {partition_types}. Each type in the sequence must be unique" + ) + + typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + types_to_match = _get_matching_types(partition_type) + partitions = get_source_partitions(gm.graph, types_to_match, filter_fn) + typed_partitions[partition_type] = list( + itertools.chain.from_iterable(partitions.values()) + ) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [ + candidate + for candidate in fusion_candidates + if _partitions_sequential(candidate) + ] + return fused_partitions + + +def _get_submodule( + graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int +) -> tuple[str, torch.nn.Module, torch.fx.Node]: + submod_node = node.args[arg_index] + if not isinstance(submod_node, torch.fx.Node): + raise AssertionError( + f"Expected submod_node to be a torch.fx.Node, got {type(submod_node)}" + ) + if submod_node.op != "get_attr": + raise AssertionError( + f"Expected submod_node.op to be 'get_attr', got {submod_node.op}" + ) + if not isinstance(submod_node.target, str): + raise AssertionError( + f"Expected submod_node.target to be a string attribute name, got {type(submod_node.target)}" + ) + submodule = graph_module.get_submodule(submod_node.target) + # pyre-ignore + return submod_node.target, submodule, node + + +def _get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing a + tuple of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + control_flow_submodules = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target is torch.ops.higher_order.cond: + control_flow_submodules.append(_get_submodule(graph_module, node, 1)) + control_flow_submodules.append(_get_submodule(graph_module, node, 2)) + if node.target is torch.ops.higher_order.map_impl: + control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + + return control_flow_submodules + + +def bfs_trace_with_node_process( + model: ExportedProgram | torch.fx.GraphModule, node_op: Callable +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + if not isinstance(model, (ExportedProgram, torch.fx.GraphModule)): + raise AssertionError( + f"Expected GraphModule or ExportedProgram, got {type(model)}" + ) + gm = model.graph_module if isinstance(model, ExportedProgram) else model + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if node.op in ["output", "placeholder"]: + continue + + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in _get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..c306b1745badaf575060a6a2fb4ed21f6977ab75 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py @@ -0,0 +1,60 @@ +import torch +from torch._inductor.constant_folding import constant_fold +from torch._inductor.fx_passes.freezing_patterns import freezing_passes + + +__all__ = [ + "lower_pt2e_quantized_to_x86", +] + + +def lower_pt2e_quantized_to_x86( + model: torch.fx.GraphModule, + example_inputs: tuple[torch.Tensor, ...], +) -> torch.fx.GraphModule: + """Lower a PT2E-quantized model to x86 backend. + + Args: + * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. + * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + + Return: + A GraphModule lowered to x86 backend. + """ + + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] + decomp_table = torch.export.default_decompositions() + + # if we are post-autograd, we shouldn't + # decomp prim ops. + for k in list(decomp_table.keys()): + if not torch._export.utils._is_cia_op(k): + del decomp_table[k] + + return decomp_table + + def _node_replace(m): # type: ignore[no-untyped-def] + # Replace aten.t(x) with aten.permute(x, [1, 0]) + aten = torch.ops.aten + g = m.graph + for node in g.nodes: + if node.target is aten.t.default: + with g.inserting_before(node): + x = node.args[0] + dims = [1, 0] + perm_node = g.call_function(aten.permute.default, args=(x, dims)) + node.replace_all_uses_with(perm_node) + g.erase_node(node) + + g.lint() + m.recompile() + + lowered_model = ( + torch.export.export(model, example_inputs, strict=True) + .run_decompositions(_post_autograd_decomp_table()) + .module() + ) + _node_replace(lowered_model) + freezing_passes(lowered_model, example_inputs) + constant_fold(lowered_model) + return lowered_model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..be5878042b046447e446c3f4ee1cb1d761f29f27 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -0,0 +1,217 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._export.error import InternalError +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _find_q_dq_node_for_user, + _is_valid_annotation, +) +from torch.ao.quantization.quantizer import QuantizationSpecBase +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + +__all__ = ["PortNodeMetaForQDQ"] + +_METADATA_TO_PORT = [ + "stack_trace", + "quantization_tag", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.dequantize_affine, +] + +_CHOOSE_QPARAMS_OPS = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.pt2e_quant.choose_qparams_affine, +] + + +def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: + from_meta = from_node.meta + for meta_name in _METADATA_TO_PORT: + if meta_name in from_meta: + to_node.meta[meta_name] = from_meta[meta_name] + + +def _has_quant_annotation(node: torch.fx.Node) -> bool: + return "quantization_annotation" in node.meta + + +def _find_choose_qparams_node(node: torch.fx.Node) -> torch.fx.Node | None: + # BFS to look for choose qparams + from collections import deque + + queue = deque(list(node.users.keys())) + while len(queue): + n = queue.popleft() + if n.op == "output": + continue + if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: + return n + for k in n.users: + queue.append(k) + return None + + +def _port_metadata_for_input_quant_nodes( + input_node: torch.fx.Node, + node: torch.fx.Node, + qspec: QuantizationSpecBase | None, +): + if qspec is None: + return + + is_dynamic_quant = getattr(qspec, "is_dynamic", None) + if is_dynamic_quant is not None and is_dynamic_quant is True: + choose_qparams_node = _find_choose_qparams_node(input_node) + if choose_qparams_node is None: + raise ValueError(f"No chose qparams node found for {node}") + choose_qparam_users = _filter_sym_size_users(choose_qparams_node) + if len(choose_qparam_users) != 2: + raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") + scale_node = choose_qparam_users.pop() + dynamic_q_node = next(iter(scale_node.users.keys())) + dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) + if len(dynamic_q_node_users) > 1: + raise InternalError(f"Expecting single user for {dynamic_q_node}") + dynamic_dq_node = dynamic_q_node_users.pop() + _add_metadata(choose_qparams_node, node) + _add_metadata(dynamic_q_node, node) + _add_metadata(dynamic_dq_node, node) + else: + q_node, dq_node = _find_q_dq_node_for_user(input_node, node) + if q_node is None or dq_node is None: + return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while ( + isinstance(q_node_input, torch.fx.Node) + and q_node_input.op == "call_function" + and q_node_input.target + in [ + torch.ops.aten.flatten.using_ints, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten._mkldnn_transpose, + ] + ): + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) + _add_metadata(dq_node, node) + + +def _port_metadata_for_output_quant_nodes( + node: torch.fx.Node, qspec: QuantizationSpecBase | None +): + if qspec is None: + return + + node_users = _filter_sym_size_users(node) + if len(node.users) == 0: + return + if len(node_users) != 1: + logger.warning(f"Expecting {node} to have single user") # noqa: G004 + q_node = node_users.pop() + if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: + logger.warning( + f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 + ) # noqa: G004 + return + + _add_metadata(q_node, node) + + +class PortNodeMetaForQDQ(PassBase): + """ + Port metadata for nodes added by quantization flow. + For static quant these are: + - quantizer_per_tensor.default, dequantize_per_tensor.default + - quantizer_per_channel.default, dequantize_per_channel.default + For dynamic quant these are: + - choose_qparams.tensor + - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor + - quantizer_per_channel.default, dequantize_per_channel.default + + Rules of porting metadata: + - Metadata to be ported: + - nn_module_stack + - stack_trace + - quantization_tag + - Metadata to NOT be ported: + - Everything else + - Rules: + - Statically quantized patterns: + - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node. + - Quantize nodes on the outputs inherit metadata of the producer node. + - Example 1: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metadata from + - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] + - Note first Q and last DQ do not inherit metadata from any nodes + - Example 2: + - Original: [Conv -> AvgPool -> Linear] + - AvgPool is not quantized + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metadata from + - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] + - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because + AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation + on the nodes (in this case AvgPool node) to conclude if the node or pattern was + supposed to be quantized. And subsequently decide if the preceding Q, if any, should + inherit metadata from AvgPool. + - Dynamically quantized patterns: + - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes + - For example, below linear is dynamically quantized while rest statically: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear] + - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]] + - Note first Q does not inherit metadata from any nodes + NB: + - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely + knows which quantization spec is converted to q/dq and thus from where the metadata should be ported. + However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit. + Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant + code, this pass should like to be integrated in the refactored variant of "convert" step. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + annotation = node.meta.get("quantization_annotation", None) + if _is_valid_annotation(annotation): + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + output_qspec = node.meta["quantization_annotation"].output_qspec + for input_node, qspec in input_qspec_map.items(): + _port_metadata_for_input_quant_nodes(input_node, node, qspec) + _port_metadata_for_output_quant_nodes(node, output_qspec) + return PassResult(graph_module, True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3c8b4b33d881feda0864cc65698972a3226c7c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py @@ -0,0 +1,610 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + ObserverOrFakeQuantize, + QConfigMapping, +) +from torch.ao.quantization.fx.custom_config import PrepareCustomConfig +from torch.ao.quantization.fx.prepare import ( + _create_obs_or_fq_from_qspec, + _insert_obs_or_fq, + _is_activation_post_process_node, + _save_state, +) +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantizer import ( + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.ao.quantization.utils import _assert_and_get_unique_device +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Argument + + +# TODO: make pt2e folder private? +__all__ = [ + "prepare", +] + + +def _find_root_edge_or_node( + edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode] +) -> EdgeOrNode: + """Find the root node for the sharing tree + Args: + edge_or_node: edge/node that we want to find the root + shared_with_map: each edge/node points to the parent, the root node will points to itself + + Returns: + root edge/node + """ + parent = shared_with_map[edge_or_node] + if parent == edge_or_node: + return edge_or_node + root = _find_root_edge_or_node(parent, shared_with_map) + # path compression + shared_with_map[edge_or_node] = root + return root + + +def _union( + parent: EdgeOrNode, + child: EdgeOrNode, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> None: + """Merge the subtree for `child` with `parent`, the order is important here""" + root_parent = _find_root_edge_or_node(parent, shared_with_map) + root_child = _find_root_edge_or_node(child, shared_with_map) + # union the two trees by pointing the root of child to root of parent + shared_with_map[root_child] = root_parent + + +def _update_shared_with( + child: EdgeOrNode, + qspec: QuantizationSpecBase, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +): + """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` + configuration and established the relationship between `edge_or_node` with the edge/node that it + is pointing to, we'll use this information in the end to get the group id + """ + if isinstance(qspec, SharedQuantizationSpec): + parent = qspec.edge_or_node + # we point from edge_or_node to the node that it is sharing_with, e.g. + # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` + _union(parent, child, shared_with_map) + + +def _unwrap_shared_qspec( + qspec: QuantizationSpecBase, + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> QuantizationSpecBase: + """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) + if qspec is SharedQuantizationSpec + (1). tries to find the root edge or node for the node that the qspec points to + (2). recursively find the root qspec based on the qspec for the root node + """ + if isinstance(qspec, SharedQuantizationSpec): + sharing_with = qspec.edge_or_node + root = _find_root_edge_or_node(sharing_with, shared_with_map) + qspec = edge_or_node_to_qspec[root] + return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + return qspec + + +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): + return ( + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) + + +def _get_edge_or_node_to_qspec( + model: torch.fx.GraphModule, +) -> dict[EdgeOrNode, QuantizationSpecBase]: + """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {} + for n in model.graph.nodes: + if hasattr(n, "meta") and "quantization_annotation" in n.meta: + qa = n.meta["quantization_annotation"] + for input_to_n, qspec in qa.input_qspec_map.items(): + input_edge = (input_to_n, n) + edge_or_node_to_qspec[input_edge] = qspec + if qa.output_qspec is not None: + output_node = n + qspec = qa.output_qspec + edge_or_node_to_qspec[output_node] = qspec + return edge_or_node_to_qspec + + +def _union_input_edge_with( + input_edge, + input_edge_root_qspec, + edge_or_node, + edge_or_node_to_qspec, + shared_with_map, +): + """Union input edge with another edge or node, used in implicit sharing to point the current input + edge to other user edges of the producer node, or the output of producer node since these are + referring to the same Tensor + """ + root_qspec = None + if edge_or_node in edge_or_node_to_qspec: + qspec = edge_or_node_to_qspec[edge_or_node] + root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + # TODO: add assertions for types of root qspecs + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] + ): + # the input arg to the node should reuse the existing output observer for arg + # since dtype is the same (we may want to extend this to be a more strict check + # in the future) + # so we point from `input_edge` to `arg` (output of the argument) + _union(edge_or_node, input_edge, shared_with_map) + + +def _get_edge_or_node_to_group_id( + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], +) -> dict[EdgeOrNode, int]: + """Map from edge/node to the group ID, generated from quantization annotations, + edge/node with the same group ID should use the same observer/fake_quant instance + + This is applying SharedQuantizationSpec configuration and map each edge/node to a group + There is another implicit sharing that's built in the quantization, when we have the following: + * op1 -> op2 + * output of op1: int8_qspec + * (op1 -> op2) input edge: int8_qspec + we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. + + Figuring out the correct group ID for all edge/node is a standard union find problem: + https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ + + Args: + edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations + Returns: + edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that + belongs to the same group should have the same id + + Example: + op2 -> cat1 -> cat2 + op1 / / + op3 + edge_or_node_to_qspec: { + op1: int8_qspec, + op2: int8_qspec, + (op1, cat1): int8_qspc, + (op2, cat1): SharedQuantizationSpec((op1, cat1)), + cat1: SharedQuantizationSpec((op1, cat1)), + (op3, cat2): int8_qspec, + (cat1, cat2): SharedQuantizationSpec((op3, cat2)), + cat2: SharedQuantizationSpec((op3, cat2)), + } + + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + edge_or_node_to_group_id: { + op1: 1, + op2: 1, + (op1, cat1): 1, + (op2, cat1): 1, + cat1: 1, + (op3, cat2): 1, + (cat1, cat2): 1, + cat2: 1, + } + # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which + # connects the two sharing group around cat1 and cat2 op due to transitive sharing + """ + # means the observer of key should be shared with observer with value, by default it will + # be shared with itself + shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { + k: k for k in edge_or_node_to_qspec + } + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + if isinstance(edge_or_node, torch.fx.Node): + output_node = edge_or_node + _update_shared_with(output_node, qspec, shared_with_map) + else: + input_edge = edge_or_node + input_edge_root_qspec = _unwrap_shared_qspec( + qspec, edge_or_node_to_qspec, shared_with_map + ) + + if not isinstance(input_edge, tuple): + raise AssertionError( + f"input_edge must be a tuple (arg, user), got {type(input_edge)}" + ) + arg, n = input_edge + if n.meta["quantization_annotation"].allow_implicit_sharing: + # NOTE: the order is important here, we first share with other users and then share with previous + # output because the reverse order could cause circular dependency + # e.g node1 -> node2 + # \ -> node3 + # when processing (node1, node2), if we first point (node1, node2) to node1 + # Step 1. shared_map = {(node1, node2): node1} + # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) , + # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} + # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) + # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll + # have a circular dependency + # the following order works around this issue, but this does not allow arbitrary configuration + # of sharing so it might break in a different case in the future, when it breaks + # quantizer writer can check the notes here to debug the issue + + # sharing with other users of the producer node + # (arg, user) + if not isinstance(arg, Node) or not isinstance(n, Node): + raise Exception( # noqa: TRY002 + f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}" + ) + for user in arg.users: + if user is n: + continue + arg_to_user_edge = (arg, user) + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg_to_user_edge, + edge_or_node_to_qspec, + shared_with_map, + ) + + # sharing with output of producer node + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg, + edge_or_node_to_qspec, + shared_with_map, + ) + + _update_shared_with(input_edge, qspec, shared_with_map) + + # now that we get the sharing relations between all edges and nodes, we can assign group ids + cur_group_id = 0 + edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} + for edge_or_node in shared_with_map: + root = _find_root_edge_or_node(edge_or_node, shared_with_map) + if root not in edge_or_node_to_group_id: + edge_or_node_to_group_id[root] = cur_group_id + cur_group_id += 1 + edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] + + return edge_or_node_to_group_id + + +def _get_obs_or_fq_map( + edge_or_node_to_group_id: dict[EdgeOrNode, int], + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + is_qat: bool, +) -> dict[EdgeOrNode, ObserverOrFakeQuantize]: + """Generates the EdgeOrNode to observer/fake_quant instances + Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant + instances + """ + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + group_id = edge_or_node_to_group_id[edge_or_node] + if group_id not in group_id_to_obs_or_fq: + # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify + # the implementation for _create_obs_or_fq_from_qspec + group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec( + qspec, obs_or_fq_map, is_qat + ) + obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] + return obs_or_fq_map + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Node | Any, + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + if not isinstance(arg, Node): + raise AssertionError( + f"expect original argument to be a Node, but got: {type(arg)}" + ) + # default (no observer) + new_arg = arg + + # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes + original_arg = arg + while _is_activation_post_process_node(original_arg, named_modules): + original_arg = original_arg.args[0] # type: ignore[assignment] + if not isinstance(original_arg, Node): + raise AssertionError( + f"expect original argument to be a Node, but got: {type(original_arg)}" + ) + + input_edge = (original_arg, node) + if input_edge not in obs_or_fq_map: + return new_arg + # input_edge needs to be observed + input_edge_obs_or_fq = obs_or_fq_map[input_edge] + if input_edge_obs_or_fq is None: + return new_arg + + arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg) + # the arg is observed as the output and is using the same instance as the input_edge + # we'll reuse the inserted observer/fake_quant + if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id( + input_edge_obs_or_fq + ): + return new_arg + + # otherwise, we'll insert a new observer/fake_quant node + + # skip inserting new observers if the same observer instance is inserted before for another user + # Example: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + # + # instead of inserting new observers we will have: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + for maybe_obs_node in arg.users: + if not _is_activation_post_process_node(maybe_obs_node, named_modules): + continue + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if id(maybe_obs_mod) == id(input_edge_obs_or_fq): + return maybe_obs_node + + if not isinstance(model.graph, Graph): + raise AssertionError( + f"Expected model.graph to be a torch.fx.Graph, got {type(model.graph)}" + ) + new_arg = _insert_obs_or_fq( + arg, + input_edge_obs_or_fq, + model, + named_modules, + model.graph, + model_device, + ) + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + new_args.append(new_arg) + + # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and + # gelu has a has an approximate kwarg that persist in exported graph. + # This is just a work around for these. + if not ( + node.target is torch.ops.aten.clone.default + or node.target is torch.ops.aten.zeros_like.default + or node.target is torch.ops.aten.gelu.default + or len(node.kwargs) == 0 + ): + raise AssertionError(" expecting kwargs for aten op IR to be empty") + + # assign the new args to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> Node | None: + if node in obs_or_fq_map: + output_act_obs_or_fq = obs_or_fq_map[node] + new_output = _insert_obs_or_fq( + node, + output_act_obs_or_fq, + model, + named_modules, + graph, + model_device, + ) + # propagate numeric debug handle from original node to observer/fake_quant node + if ( + isinstance(node, Node) + and isinstance(new_output, Node) + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + return new_output + return None + + +def _maybe_insert_input_and_output_observers_for_node( + node: Node, + model: torch.fx.GraphModule, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +): + this_node_quantization_annotation = node.meta.get("quantization_annotation", None) + if this_node_quantization_annotation is None: + return + + named_modules = dict(model.named_modules(remove_duplicate=False)) + _maybe_insert_input_observers_for_node( + node, + None, # qconfig + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + + output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) + if not output_is_a_tensor: + return + + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + model_device, + ) + + if maybe_output_obs_node is None: + return + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + + +def prepare( + model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + is_qat: bool, + obs_or_fq_callback=None, +) -> GraphModule: + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance + # all edge/nodes that belongs to the same group will use the same instance + # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant + # instance + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map( + edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat + ) + if obs_or_fq_callback: + obs_or_fq_callback(model, obs_or_fq_map) + model_device = _assert_and_get_unique_device(model) + + for node in nodes_before_observation: + # TODO: simplify logic for inserting observers + _maybe_insert_input_and_output_observers_for_node( + node, + model, + obs_or_fq_map, + is_qat, + model_device, + ) + + model = GraphModule(model, model.graph) + + _save_state( + model, + {}, # node_name_to_qconfig + node_name_to_scope, + PrepareCustomConfig(), + {}, # equalization_node_name_to_qconfig + QConfigMapping(), + is_qat, + set(), # observed_node_names + ) + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9498a4f16f78f256baba85246dabeb458c9764c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py @@ -0,0 +1,1058 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import operator +from collections.abc import Callable +from typing import Any, TYPE_CHECKING + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.fx import Graph, GraphModule, Node +from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns + +from .utils import ( + _get_aten_graph_module_for_pattern, + _is_bn_node, + _is_conv_or_conv_transpose_node, + _is_conv_transpose_fn, + fold_bn_weights_into_conv_node, +) + + +if TYPE_CHECKING: + from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [] # type: ignore[var-annotated] + + +def _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + is_cuda: bool, +) -> dict[str, Any]: + """ + Optional example inputs for quantized and folded conv-bn patterns + used in convert, expressed as kwargs. + """ + kwargs = {} + # Per tensor quantization uses literals to represent scale and zero + # point, so there is no need to include them here as kwargs + if is_per_channel: + kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias and bias_is_quantized: + kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias: + kwargs["conv_bias"] = torch.randn(1) + if is_cuda: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.cuda() + return kwargs + + +def _get_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + x = conv_fn(x, conv_weight, conv_bias) + x = F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + return x + + return _WrapperModule(_conv_bn_pattern) + + +# TODO: merge this with the `no_conv_bias` case +def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std. + This is based on `nniqat.ConvBn2d._forward_approximate`. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype) + x = conv_fn(x, scaled_weight, zero_bias) + x = x / scale_factor.reshape(bias_shape) + x = x + conv_bias.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern) + + +def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern_no_conv_bias( + x: torch.Tensor, + conv_weight: torch.Tensor, + # Not used, only for matching convenience + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias) + + +def _append_qdq(x, is_per_channel, is_bias, kwargs): + """ + Helper function to append q-dq ops after `x`, using dummy values for the qparams + and qmin/qmax. We use dummy values here because we match with `ignore_literals=True` + and will manually replace these values after subgraph rewriting. + + Return the dq node. + """ + # Dummy args to be passed into q-dq ops + per_channel_axis = 0 + scale_key = "bias_scale" if is_bias else "weight_scale" + zp_key = "bias_zero_point" if is_bias else "weight_zero_point" + scale = kwargs[scale_key] if is_per_channel else 1.0 + zp = kwargs[zp_key] if is_per_channel else 0 + qmin = -127 + qmax = 127 + dtype = torch.int8 + + qd = torch.ops.quantized_decomposed + if is_per_channel: + x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + else: + x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + return x + + +def _get_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Return the quantized version of QAT conv + BN pattern. + This is based on `nniqat.ConvBn2d._forward_approximate`, + used in QAT convert. We first match this pattern and replace + it with the normal [conv - bn] pattern, then fold the BN + weights into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + scaled_weight = _append_qdq( + scaled_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype) + if bias_is_quantized: + zero_bias = _append_qdq( + zero_bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + x = conv_fn(x, scaled_weight, zero_bias) + else: + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + if has_bias: + x = x + kwargs["conv_bias"].reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_quantized_qat_conv_bn_pattern) + + +def _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Quantized QAT conv - bn pattern with bn weights being folded into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _folded_quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + conv_weight = _append_qdq( + conv_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + bias = kwargs["conv_bias"] + if bias_is_quantized: + bias = _append_qdq( + bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + else: + bias = None + x = conv_fn(x, conv_weight, bias) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_folded_quantized_qat_conv_bn_pattern) + + +def _has_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph has bias. + """ + for n in match.nodes_map.values(): + if _is_conv_or_conv_transpose_node(n): + return len(n.args) > 2 and n.args[2] is not None + raise ValueError("Could not find conv node in matched conv + bn pattern") + + +def _no_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph does NOT have bias. + """ + return not _has_conv_bias_filter(match, original_graph, pattern_graph) + + +def _is_quantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def _is_dequantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]: + """ + Helper function to extract the nodes in the conv-bn fusion pattern after + subgraph rewriting, in the form of a map: + + {name: (original_node, replacement_node)} + + The following names must exist in the map: + + "conv", "conv_weight", "conv_input", "bn", "getitem" + + The following names may exist in the map: + + "conv_weight_q", "conv_weight_dq", "conv_bias", + "conv_bias_q", "conv_bias_dq" + """ + + def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Node | None]: + """ + Return a 3-tuple of (conv_node, bn_node, getitem_node). + This asserts that the match contains exactly one of each node. + """ + conv_node, bn_node, getitem_node = None, None, None + for n in nodes: + if n.op != "call_function": + continue + if _is_conv_or_conv_transpose_node(n): + if conv_node is not None: + raise AssertionError( + f"Found multiple conv nodes in match, previous: {conv_node}, new: {n}" + ) + conv_node = n + if _is_bn_node(n): + if bn_node is not None: + raise AssertionError( + f"Found multiple bn nodes in match, previous: {bn_node}, new: {n}" + ) + bn_node = n + if n.target is operator.getitem: + if getitem_node is not None: + raise AssertionError( + f"Found multiple getitem nodes in match, previous: {getitem_node}, new: {n}" + ) + getitem_node = n + if conv_node is None: + raise AssertionError( + "Expected exactly one conv node in the match, found none" + ) + if bn_node is None: + raise AssertionError( + "Expected exactly one bn node in the match, found none" + ) + return (conv_node, bn_node, getitem_node) + + def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]: + """ + Return a 3-tuple of (orig_node, q_node, dq_node). + """ + if not _is_dequantize(n): + raise AssertionError(f"Expected a dequantize node, got: {n}") + q_node = n.args[0] + if not isinstance(q_node, Node): + raise AssertionError( + f"Expected quantize node to be a torch.fx.Node, got {type(q_node)}" + ) + if not _is_quantize(q_node): + raise AssertionError( + f"Expected q_node to be a quantize node, got target={q_node.target}" + ) + orig_node = q_node.args[0] + if not isinstance(orig_node, Node): + raise AssertionError( + f"Expected original node to be a torch.fx.Node, got {type(orig_node)}" + ) + return (orig_node, q_node, n) + + original_nodes = list(_filter_nodes_map(r.nodes_map).values()) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) + + # Create the mapping from original node to replacement node + if o_getitem is not None: + raise AssertionError(f"Expected o_getitem to be None, got {o_getitem}") + if r_getitem is not None: + raise AssertionError(f"Expected r_getitem to be None, got {r_getitem}") + mapping = { + "conv": (o_conv, r_conv), + "bn": (o_bn, r_bn), + } + + # Extract conv input and weight + # Note: here we extract the original nodes indirectly through the pattern nodes + # because the args of the original nodes are no longer available after replacement + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv_input, p_conv_weight, *_) = p_conv.args + (r_conv_input, r_conv_weight, *_) = r_conv.args + if not isinstance(p_conv_input, Node): + raise AssertionError( + f"Expected p_conv_input to be a Node, got {type(p_conv_input)}" + ) + if not isinstance(p_conv_weight, Node): + raise AssertionError( + f"Expected p_conv_weight to be a Node, got {type(p_conv_weight)}" + ) + if not isinstance(r_conv_input, Node): + raise AssertionError( + f"Expected r_conv_input to be a Node, got {type(r_conv_input)}" + ) + if not isinstance(r_conv_weight, Node): + raise AssertionError( + f"Expected r_conv_weight to be a Node, got {type(r_conv_weight)}" + ) + o_conv_input = r.nodes_map[p_conv_input] + o_conv_weight = r.nodes_map[p_conv_weight] + + # If conv weight is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_weight): + p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes( + p_conv_weight + ) + r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes( + r_conv_weight + ) + o_conv_weight = r.nodes_map[p_conv_weight] + o_conv_weight_q = r.nodes_map[p_conv_weight_q] + o_conv_weight_dq = r.nodes_map[p_conv_weight_dq] + mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q) + mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq) + mapping["conv_input"] = (o_conv_input, r_conv_input) + mapping["conv_weight"] = (o_conv_weight, r_conv_weight) + + # Extract conv bias + if len(p_conv.args) > 2 and len(r_conv.args) > 2: + p_conv_bias = p_conv.args[2] + r_conv_bias = r_conv.args[2] + if not isinstance(p_conv_bias, Node): + raise AssertionError( + f"Expected p_conv_bias to be a Node, got {type(p_conv_bias)}" + ) + if not isinstance(r_conv_bias, Node): + raise AssertionError( + f"Expected r_conv_bias to be a Node, got {type(r_conv_bias)}" + ) + o_conv_bias = r.nodes_map[p_conv_bias] + + # If conv bias is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_bias): + p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias) + r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias) + o_conv_bias = r.nodes_map[p_conv_bias] + o_conv_bias_q = r.nodes_map[p_conv_bias_q] + o_conv_bias_dq = r.nodes_map[p_conv_bias_dq] + mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q) + mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq) + mapping["conv_bias"] = (o_conv_bias, r_conv_bias) + return mapping + + +def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]: + """ + Return a filtered `nodes_map` returned from the subgraph rewriter. + The filtered `nodes_map` will contain only nodes that are actually + matched in the pattern, excluding None or placeholder nodes. + """ + new_nodes_map: dict[Node, Node] = {} + for pattern_node, graph_node in nodes_map.items(): + # bias can be None + if graph_node is None: + continue + # skip pattern placeholder nodes + if pattern_node.op == "placeholder": + continue + new_nodes_map[pattern_node] = graph_node + return new_nodes_map + + +# TODO: this is error prone, use the replace_literals_with_placeholders hack instead +def _copy_over_literal_conv_args(original_node: Node, new_node: Node): + """ + Copy over literal args in conv, such as stride and padding, from the matched node + in the original graph to its replacement in the new graph. + + This is needed due to the following limitation in the subgraph rewriter when used + with dynamo export: literal (non-tensor) args are not supported in the match and + replacement patterns. This is because dynamo export automatically inlines these + literal args, making them dead placeholder nodes. In the future, we should check + if dynamo export can optionally disable this inlining, or if subgraph rewriter + can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419. + + Note: Unlike other tensor args like conv weights and biases, literal args are + preserved in the original nodes after replacement, so we can access them here. + """ + if not _is_conv_or_conv_transpose_node(original_node): + raise AssertionError( + f"Expected original_node to be a conv node, got {original_node}" + ) + if not _is_conv_or_conv_transpose_node(new_node): + raise AssertionError(f"Expected new_node to be a conv node, got {new_node}") + # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups] + new_args = list(new_node.args) + if len(new_args) < 3: + # bias is optional, when it is not present, it means it is None + new_args.append(None) + new_node.args = tuple(new_args[:3]) + original_node.args[3:] + + +def _update_conv_input_qspec_map_after_replacement( + original_node: Node, replacement_node: Node +): + """ + Update the `input_qspec_map` in the annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the keys in the `input_qspec_map` will need to be updated to reflect + the corresponding nodes in the replacement graph. + """ + if not _is_conv_or_conv_transpose_node(original_node): + raise AssertionError( + f"Expected original_node to be a conv node, got {original_node}" + ) + if not _is_conv_or_conv_transpose_node(replacement_node): + raise AssertionError( + f"Expected replacement_node to be a conv node, got {replacement_node}" + ) + if "quantization_annotation" not in original_node.meta: + return + original_input_qspec_map = original_node.meta[ + "quantization_annotation" + ].input_qspec_map + input_qspec_map = {} + # get the list of configs, it should be ordered as input, weight, bias + # note: this is really hacky, we need a better solution, hopefully + # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820 + all_configs = list(original_input_qspec_map.items()) + # input activation + input_qspec_map[replacement_node.args[0]] = all_configs[0][1] + # weight + input_qspec_map[replacement_node.args[1]] = all_configs[1][1] + # bias + if len(replacement_node.args) > 2 and len(all_configs) > 2: + input_qspec_map[replacement_node.args[2]] = all_configs[2][1] + replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map + + +def _update_special_qspecs_after_replacement( + node: Node, + original_to_replacement_node: dict[Node, Node], +): + """ + Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s + used in `node`'s quantization annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the nodes used in these special quantization specs will need to + be updated to the corresponding nodes in the replacement graph. + """ + + def _get_new_edge_or_node(edge_or_node: EdgeOrNode): + if isinstance(edge_or_node, Node): + _node = edge_or_node + return original_to_replacement_node.get(_node, _node) + elif ( + isinstance(edge_or_node, tuple) + and len(edge_or_node) == 2 + and all(isinstance(x, Node) for x in edge_or_node) + ): + src, dest = edge_or_node + return ( + original_to_replacement_node.get(src, src), + original_to_replacement_node.get(dest, dest), + ) + else: + raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node)) + + def _get_new_qspec(qspec: QuantizationSpecBase): + if isinstance(qspec, SharedQuantizationSpec): + new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node) + return SharedQuantizationSpec(new_edge_or_node) + elif isinstance(qspec, DerivedQuantizationSpec): + new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from] + return dataclasses.replace(qspec, derived_from=new_derived_from) + else: + return qspec + + if "quantization_annotation" not in node.meta: + return + annotation = node.meta["quantization_annotation"] + for input_node, qspec in annotation.input_qspec_map.items(): + annotation.input_qspec_map[input_node] = _get_new_qspec(qspec) + annotation.output_qspec = _get_new_qspec(annotation.output_qspec) + + +def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fuse_conv_bn_qat_helper( + m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + return m + + +def _fuse_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Given a graph of decomposed aten ops, replace the (conv + bn) pattern with + the fused QAT subgraph equivalent. The input graph should already be annotated. + The annotations in the original nodes will be preserved in the corresponding + nodes in the new subgraph. + + Note: This also handles the (conv + bn + relu) pattern. + """ + m.graph.eliminate_dead_code() + m.recompile() + + conv_bn_pattern = _get_conv_bn_pattern(conv_fn) + match_pattern = _get_aten_graph_module_for_pattern( + conv_bn_pattern, + example_inputs, + is_cuda, + ) + + # Step (1): Replace patterns with conv bias + # + # Here we do replacement separately for cases with and without conv bias, since + # the replacement patterns for these two cases are substantially different. + # TODO: use the public replace_pattern API once it also returns replacement nodes + + qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn) + replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern, + example_inputs, + is_cuda, + ) + replacements_with_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_with_conv_bias, + match_filters=[_has_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (2): Replace patterns without conv bias + + qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn) + replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern_no_conv_bias, + example_inputs, + is_cuda, + ) + replacements_no_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_no_conv_bias, + match_filters=[_no_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (3): Post processing + # + # Due to limited functionality in the subgraph rewriter, here we manually + # update the replacement graph as follows: + # + # (a) Copy over metadata from original subgraph. This ensures the stack traces + # and annotations are preserved in the new subgraph + # + # (b) Copy over literal args for conv from the original subgraph + # TODO: do this for literal args for batchnorm as well + # + # (c) Update all references of the old nodes in the original subgraph to refer + # to the corresponding nodes in the new subgraph in the annotations + # + # In the future, we should try to push as much of this functionality into the + # subgraph rewriter as possible, so we don't have to manually copy anything over. + # For more detail, see https://github.com/pytorch/pytorch/issues/100419. + + all_original_to_replacement_nodes = {} + for r in replacements_with_conv_bias + replacements_no_conv_bias: + replacement_dict = _get_conv_bn_pattern_nodes(r) + # The original conv node's "nn_module_stack" + conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None) + for k, node_tuple in replacement_dict.items(): + original_node, replacement_node = node_tuple + # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem] + replacement_node.meta = original_node.meta + # If original_node is a get_attr node, it doesn't have nn_module_stack. + # In this case, we copy nn_module_stack from the original conv node. + if ( + k in ["conv_input", "conv_weight"] + and conv_nn_module + and "nn_module_stack" not in replacement_node.meta + ): + replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module) + if _is_conv_or_conv_transpose_node(original_node): + # Step (3b): Copy over conv literal args + _copy_over_literal_conv_args(original_node, replacement_node) + # Step (3c): Update old references in the conv node's input_qspec_map + _update_conv_input_qspec_map_after_replacement( + original_node, replacement_node + ) + all_original_to_replacement_nodes[original_node] = replacement_node + + # Step (3c): Update old references in the special qspecs for all nodes in the graph + for n in m.graph.nodes: + _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes) + + return m + + +def _duplicate_dequantize_node(m: GraphModule): + """ + Helper function to duplicate all dequantize nodes in the graph if the + node has more than one user. For example: + + Before: + quantize -> dequantize -> a + \\--> b + \\--> c + + After: + quantize -> dequantize_1 -> a + \\--> dequantize_2 -> b + \\--> dequantize_3 -> c + + This is useful for subgraph rewriting. E.g. if we wish to match the + pattern [dequantize - a] above, subgraph matching would fail because + the dequantize node has users outside the matched portion of the graph. + Instead, we match [dequantize_1 - a], which is safe. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + if n.op != "call_function" or n.target != dq_op or len(n.users) == 1: + continue + for user in list(n.users): + with m.graph.inserting_before(n): + new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs) + user.replace_input_with(n, new_node) + m.graph.erase_node(n) + m.recompile() + + +def _remove_extra_dequantize(m: GraphModule): + """ + Removes duplicate dequant nodes in the graph, for an operator that has + multiple dequant nodes as a user. Replace them with a single dequant node + that can be shared across all uses. This should be seen as the "reverse" + of `_duplicate_dequantize_node`. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + dq_users = [ + user + for user in n.users + if user.op == "call_function" and user.target == dq_op + ] + if len(dq_users) > 1: + with m.graph.inserting_after(dq_users[0]): + new_node = m.graph.create_node( + "call_function", dq_op, dq_users[0].args, {} + ) + for dq_user in dq_users: + dq_user.replace_all_uses_with(new_node) + m.graph.erase_node(dq_user) + m.recompile() + + +def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): + """ + Given a pair of quantize or dequantize nodes, copy over all literal args + from the original node to the replacement node. + """ + # For quantize_per_tensor, scale and zp are literals and need to be copied + # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped + if original_node.target != replacement_node.target: + raise AssertionError( + "Expected original and replacement nodes to have the same target, got " + f"{original_node.target} != {replacement_node.target}" + ) + if original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Args: input, [scale, zp, qmin, qmax, dtype] + start_copy_arg_index = 1 + elif original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ): + # Args: input, scale, zp, [axis, qmin, qmax, dtype] + start_copy_arg_index = 3 + else: + raise ValueError( + f"Expected quantize/dequantize nodes, got '{original_node.target}'" + ) + replacement_node.args = ( + replacement_node.args[:start_copy_arg_index] + + original_node.args[start_copy_arg_index:] + ) + + +def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for quantized and folded conv-bn1d patterns used in convert + _quantized_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for quantized and folded conv-bn2d patterns used in convert + _quantized_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fold_conv_bn_qat_helper( + m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + + # remove in place add from batchnorm tracking training stats + for node in m.graph.nodes: + if ( + node.target is torch.ops.aten.add_.Tensor + and node.args[0].op == "get_attr" + and node.args[1] == 1 + and ( + torch.nn.modules.batchnorm.BatchNorm2d + in [val[1] for val in node.meta["source_fn_stack"]] + or torch.nn.modules.batchnorm.BatchNorm1d + in [val[1] for val in node.meta["source_fn_stack"]] + ) + ): + m.graph.erase_node(node) + + m.graph.eliminate_dead_code() + m.recompile() + + return m + + +def _fold_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv. + """ + + m.graph.eliminate_dead_code() + m.recompile() + _duplicate_dequantize_node(m) + + # Step (1): Replace QAT pattern with simple [conv - bn] pattern + replacements = [] + replacement_options = itertools.product( + [True, False], # is_per_channel + [True, False], # has_bias + [True, False], # bias_is_quantized + [True, False], # bn_is_training + ) + for ( + is_per_channel, + has_bias, + bias_is_quantized, + bn_is_training, + ) in replacement_options: + # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily + # filter out one of the values for this flag to avoid having duplicate patterns + if not has_bias and bias_is_quantized: + continue + kwargs = _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel, has_bias, bias_is_quantized, is_cuda + ) + match_pattern = _get_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + match_pattern = _get_aten_graph_module_for_pattern( + match_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + replacement_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacements.extend( + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + ignore_literals=True, + ) + ) + m.recompile() + _remove_extra_dequantize(m) + + for r in replacements: + node_map = _get_conv_bn_pattern_nodes(r) + + # Step (2): Copy over metadata from original subgraph + for original_node, replacement_node in node_map.values(): + replacement_node.meta = original_node.meta + + # Step (3): Copy over args for weight (and optionally bias) q - dq nodes + _copy_over_q_dq_args(*node_map["conv_weight_q"]) + _copy_over_q_dq_args(*node_map["conv_weight_dq"]) + if "conv_bias_q" in node_map: + if "conv_bias_dq" not in node_map: + raise AssertionError( + "Expected 'conv_bias_dq' to be present in node_map when 'conv_bias_q' is present" + ) + _copy_over_q_dq_args(*node_map["conv_bias_q"]) + _copy_over_q_dq_args(*node_map["conv_bias_dq"]) + + # Step (4): Fold BN weights into conv + conv_bias = None + (_, conv_node) = node_map["conv"] + (_, bn_node) = node_map["bn"] + (_, conv_weight) = node_map["conv_weight"] + if "conv_bias" in node_map: + (_, conv_bias) = node_map["conv_bias"] + fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m) + + # Copy over literal args for conv + for original_node in _filter_nodes_map(r.nodes_map).values(): + if _is_conv_or_conv_transpose_node(original_node): + _copy_over_literal_conv_args(original_node, conv_node) + + m.graph.eliminate_dead_code() + m.recompile() + return m diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8876d439feb41929ca9b64f3f023db499eac007b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py @@ -0,0 +1,6 @@ +from .rewrite import reference_representation_rewrite + + +__all__ = [ + "reference_representation_rewrite", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc55d7539fb54f7b83b5064f33c56c6d4fd5723 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..52084784f5036a92a909ad7f044d733677e48618 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py @@ -0,0 +1,825 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from dataclasses import dataclass +from functools import partial +from typing import Any + +import torch +from torch._export.utils import _disable_aten_to_metadata_assertions +from torch._higher_order_ops.out_dtype import out_dtype +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _replace_literals_with_existing_placeholders, + _replace_literals_with_new_placeholders, + remove_tensor_overload_for_qdq_ops, +) +from torch.fx import GraphModule +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "reference_representation_rewrite", +] + + +def _qdq_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + # TODO: change to mul.Scalar + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + + +def _reference_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + # decomposed representation for quantize_per_tensor + # TODO: use out_dtype(mul, ...) here when the op is ready + x_fp32 = x_fp32 / x_scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x_fp32 = torch.round(x_fp32) # fp32 + x_i32 = x_fp32.to(dtype=torch.int32) # int32 + x_i32 = x_i32 + x_zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32 + x_i8 = x_i32.to(dtype=torch.int8) + + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + out_fp32 = acc_i32 * (x_scale * weight_scale) + return out_fp32 + + +def _qdq_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.convolution.default( + x_fp32, + weight_fp32, + bias_fp32, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.convolution.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to: + # Take linear calculation for example + # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32 + # Represent X, W fp32 as their dequant transforms + # A_fp32 = (A_q - A_zero_point)/A_scale + # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_fp32 - X_zp) * X_scale * (W_(i, k)_fp32 - W_zp) * W_scale] + bias_(i)_fp32 + # Factor out X_scale and W_scale + # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32 + # In order to addition of bias_(i)_fp32 inside, we must do + # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950 + # Note we had to multiply bias_fp32 with X_scale * W_scale = bias_scale + # Thus bias quantization to int32 must be with X_scale * W_scale + + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + # Unsqueeze to match broadcast dims + # Unfortnuately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare + # in graph pattern replacement + bias_i32 = bias_i32.unsqueeze(-1) + bias_i32 = bias_i32.unsqueeze(-1) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_fp32 = torch.ops.aten.relu(out_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + See comments for `_reference_quantized_add` for more information on + how to derive the formula for out_i8 based on x_i8 and y_i8 + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: change this to mul.Scalar? + x_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (x_i32 - x_zero_point), + (x_scale / out_scale), + ) + y_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (y_i32 - y_zero_point), + (y_scale / out_scale), + ) + out_i32 = x_i32 + y_i32 + out_zero_point + # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point) + out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + # How to Derive the formula for out_i8 based on x_i8 and y_i8 + # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8) + + # out_i8 is quantized output, we can write down the formula for it first: + out_i8 = out_f32 / out_scale + out_zero_point (1) + + # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8 + out_f32 = x_f32 + y_f32 (2) + x_fp32 = (x_i8 - x_zero_point) * x_scale (3) + y_fp32 = (y_i8 - y_zero_point) * y_scale (4) + + # applying the above formula to the out_i8 equation we can get the following: + out_i8 = out_fp32 / out_scale + out_zero_point # (1) + = (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 + = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: use out_dtype op + x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32) + y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32) + out_i32 = x_i32 + y_i32 + out_zero_point + quant_min = -128 + quant_max = 127 + out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_fp32, kernel_size, stride, padding, dilation, ceil_mode + ) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + # to preserve x_quant_min, x_quant_max in the graph for pattern matching + x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max) + x_i32 = x_i8.to(torch.int32) + out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode + ) + out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point + out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max) + out_i8 = out_fp32.to(torch.int8) + return out_i8 + + +def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x + + +def _reference_quantize_per_tensor_int8( + x_fp32, scale, zero_point, quant_min, quant_max +): + # TODO: use out_dtype(mul, ...) here when the op is ready + x = x_fp32 / scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x = torch.round(x) # fp32 + x = x.to(dtype=torch.int32) # int32 + x = x + zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x = torch.clamp(x, quant_min, quant_max) # int32 + x = x.to(dtype=torch.int8) + return x + + +def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x_fp32 + + +def _reference_dequantize_per_tensor_int8( + x_i8, scale, zero_point, quant_min, quant_max +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + # TODO: use out_dtype op + # note: x_i8.to(torch.int32) does not work here + # TODO: debug the implementation later when torchdynamo time out issue is resolved + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + +def _quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + out_i8 = torch.ops.quantized_decomposed.quantize_per_channel( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + x_fp32 = torch.transpose(x_fp32, ch_axis, -1) + out_i32 = torch.ops.aten.clamp( + torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max + ) + out_i32 = torch.transpose(out_i32, ch_axis, -1) + return out_i32.to(torch.int8) + + +def _dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_fp32 + + +def _reference_dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops) + # we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + x_i8 = torch.transpose(x_i8, ch_axis, -1) + x_i32 = x_i8.to(torch.int32) + out_fp32 = (x_i32 - zero_points).to(torch.float) * scales + out_fp32 = torch.transpose(out_fp32, ch_axis, -1) + return out_fp32 + + +def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule): + return _replace_literals_with_existing_placeholders( + gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5} + ) + + +@dataclass +class _RewriteInfo: + """Data needed for rewrite, this includes example inputs, pattern and replacement functions + and post transformation functions for the exported pattern and replacement GraphModule + """ + + # example inputs used for exporting the pattern into GraphModule + example_inputs: tuple[Any, ...] + pattern: Callable + replacement: Callable + # post transformation on the exported pattern and replacement GraphModule + pattern_post_trans: Callable[[GraphModule], GraphModule] | None = None + replacement_post_trans: Callable[[GraphModule], GraphModule] | None = None + + +def reference_representation_rewrite(model: GraphModule) -> GraphModule: + _QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (2, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, + 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + ) + + _QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_dynamic_quantized_linear), + _WrapperModule(_reference_dynamic_quantized_linear), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + ), + _RewriteInfo( + _QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_linear), + _WrapperModule(_reference_quantized_linear), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZED_CONV2d_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_conv2d), + _WrapperModule(_reference_quantized_conv2d), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add_relu), + _WrapperModule(_reference_quantized_add_relu), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add), + _WrapperModule(_reference_quantized_add), + ), + _RewriteInfo( + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_max_pool2d), + _WrapperModule(_reference_quantized_max_pool2d), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_tensor_int8), + _WrapperModule(_reference_quantize_per_tensor_int8), + ), + _RewriteInfo( + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_tensor_int8), + _WrapperModule(_reference_dequantize_per_tensor_int8), + ), + _RewriteInfo( + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_channel_int8), + _WrapperModule(_reference_quantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + _RewriteInfo( + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_channel_int8), + _WrapperModule(_reference_dequantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + ] + + remove_tensor_overload_for_qdq_ops(model) + + with _disable_aten_to_metadata_assertions(): + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = _get_aten_graph_module_for_pattern( # type: ignore[assignment] + replacement, + example_inputs, # type: ignore[arg-type] + ) + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + replace_pattern(model, pattern, replacement) + + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69a74ea6a0dfaf541a7617a16419013cce597bdd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py @@ -0,0 +1,625 @@ +# mypy: allow-untyped-defs +import operator +import types +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 +import torch.nn.functional as F +import torch.utils._pytree as pytree + +# Makes sure that quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.nn.utils.fusion import fuse_conv_bn_weights + + +__all__ = [ + "fold_bn_weights_into_conv_node", + "remove_tensor_overload_for_qdq_ops", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError( + f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}" + ) + dest = dest.args[0] + return dest == source + + +def _find_q_dq_node_for_user( + produer: torch.fx.Node, user: torch.fx.Node +) -> tuple[Any, Any]: + """ + Find q, dq pair corresponding to [producer -> q -> dq -> user] + Utils works by finding dq arg of user and ensuring it is connected to + producer + """ + dq_node = None + for n in user.args: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + for n in user.kwargs: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + return (None, None) + + q_node = None + if ( + isinstance(arg := dq_node.args[0], torch.fx.Node) + and arg.op == "call_function" + and arg.target in _QUANTIZE_OPS + ): + q_node = arg + return (q_node, dq_node) + + +def _is_sym_size_node(node: Node): + return ( + node.op == "call_function" + and node.target is torch.ops.aten.sym_size.default + or node.target is torch.ops.aten.sym_numel.default + or node.target is torch.ops.aten.sym_numel + or node.target is torch.ops.aten.sym_size + ) + + +def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]: + node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) + return node_users + + +def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: + if annotation is None: + return False + input_qspec_map = annotation.input_qspec_map + output_qspec = annotation.output_qspec + if len(input_qspec_map) == 0 and output_qspec is None: + return False + return True + + +def _get_tensor_constant_from_node(node, m): + if node is None: + return None + if node.op != "get_attr": + raise AssertionError(f"Expected node.op to be 'get_attr', got {node.op}") + target_atoms = node.target.split(".") + attr_itr = m + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def _get_all_arguments(orig_args, orig_kwargs, args_schema): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + +def _is_supported_batch_norm_for_training(node: Node): + """ + Return True if the given node refers to an aten batch norm op QAT supports. + """ + supported_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + # Note: we won't need this op anymore after batch norm consolidation + # For now, we need to continue to support it because it gives better + # training numerics than `_native_batch_norm_legit` + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.miopen_batch_norm.default, + ] + return node.target in supported_ops + + +# TODO: move this to torch/ao/quantization/utils.py +def _is_conv_node(n: Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, + ] + + +def _is_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv_transpose op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + ] + + +def _is_conv_or_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv or conv transpose op. + """ + return _is_conv_node(n) or _is_conv_transpose_node(n) + + +def _is_conv_transpose_fn(conv_fn: Callable): + return conv_fn in [F.conv_transpose1d, F.conv_transpose2d] + + +def _is_bn_node(n: Node): + return ( + _is_supported_batch_norm_for_training(n) + or n.target is torch.ops.aten._native_batch_norm_legit_no_training.default + ) + + +def fold_bn_weights_into_conv_node( + conv_node: Node, + conv_weight_node: Node, + conv_bias_node: Node | None, + bn_node: Node, + m: GraphModule, +) -> None: + # conv args: input, weight, bias, stride, padding, dilation, ... + conv_w = _get_tensor_constant_from_node(conv_weight_node, m) + conv_b = _get_tensor_constant_from_node(conv_bias_node, m) + transpose = _is_conv_transpose_node(conv_node) + + # eval bn args: input, weight, bias, running mean, running var, momentum, eps + # train bn args: input, weight, bias, running mean, running var, training, momentum, eps + bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] + bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) + bn_w = _get_tensor_constant_from_node(bn_args[1], m) + bn_b = _get_tensor_constant_from_node(bn_args[2], m) + bn_rm = _get_tensor_constant_from_node(bn_args[3], m) + bn_rv = _get_tensor_constant_from_node(bn_args[4], m) + if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default: + eps_arg_index = 6 + elif _is_supported_batch_norm_for_training(bn_node): + eps_arg_index = 7 + else: + raise ValueError("BN node target is unexpected ", bn_node.target) + bn_eps = bn_args[eps_arg_index] + + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose + ) + + # update the weight and bias for conv + conv_args = list(conv_node.args) + # filling in the default bias argument + if len(conv_args) == 2: + conv_args.append(None) + + # calling data since the fused_weight and fused_bias are nn.Parameter + weight_attr_name = conv_weight_node.target + if not isinstance(weight_attr_name, str): + raise AssertionError( + f"Expected conv_weight_node.target to be a string attribute name, got {type(weight_attr_name)}" + ) + _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) + else: + bias_attr_name = weight_attr_name + "_bias" + _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) + with m.graph.inserting_before(conv_node): + get_bias_node = m.graph.get_attr(bias_attr_name) + # NOTE: here we assume the bias of conv is not quantized! + conv_args[2] = get_bias_node + conv_node.args = tuple(conv_args) + + # native_batch_norm has 3 outputs, we expect getitem calls on the output + # and we want to replace the uses of getitem 0 with the output of conv + # + if bn_node.target is torch.ops.aten.batch_norm.default: + # With the new training ir, instead of batch_norm + getitem, + # we only have the batch_norm node. + # + # Before: + # conv -> bn -> users + # After: + # conv -> users + # bn has no users now + bn_node.replace_all_uses_with(conv_node) + else: + # Before: + # conv -> bn - (first output) -> users1 + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # After: + # conv -> (first output) -> users1 + # bn - + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # if users2 and users3 are empty then bn will be removed through dead code elimination + for user in bn_node.users: + if ( + user.op != "call_function" + or user.target != operator.getitem + or user.args[1] != 0 + ): + continue + user.replace_all_uses_with(conv_node) + + # If the BN node does not have users, erase it from the graph + # Note: we need to do this manually because the model can still be in train + # mode at this point, in which case DCE won't erase the BN node automatically + # since the node refers to a mutating op. Here we still need to call DCE first + # to get rid of the unused getitem nodes that consume the BN node. + m.graph.eliminate_dead_code() + if len(bn_node.users) == 0: + m.graph.erase_node(bn_node) + + +# fuse conv bn weights, inplace modification of the graph_module and graph +def _fuse_conv_bn_(m: GraphModule) -> None: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return + for n in m.graph.nodes: + if n.op != "call_function" or n.target not in ( + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.batch_norm.default, + ): + continue + bn_node = n + n = bn_node.args[0] + if not _is_conv_or_conv_transpose_node(n): + continue + conv_node = n + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + fold_bn_weights_into_conv_node( + conv_node, conv_weight_node, conv_bias_node, bn_node, m + ) + + m.graph.eliminate_dead_code() + m.recompile() + + +def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]: + # TODO: move this information to fx node itself + node_name_to_scope: dict[str, tuple[str, type]] = {} + for n in model.graph.nodes: + nn_module_stack = n.meta.get("nn_module_stack", None) + current_scope = ("", type(None)) + if nn_module_stack: + bt = list(nn_module_stack.values())[-1] + current_scope = (bt[0].split(".")[-1], bt[1]) + node_name_to_scope[n.name] = current_scope + return node_name_to_scope + + +def _get_aten_graph_module_for_pattern( + pattern: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool = False, + **kwargs, +) -> GraphModule: + """ + Convert the pattern to an FX graph with decomposed aten ops. + """ + if is_cuda: + example_inputs = tuple( + x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs + ) + + with torch._export.config.patch(use_new_tracer_experimental=True): + aten_pattern = torch.export.export( + pattern, # type: ignore[arg-type] + example_inputs, + kwargs, + strict=True, + ).module(check_guards=False) + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + # ep.module() adds copy_ nodes for the mutated inputs. + # For patterns, it doesn't matter + for node in aten_pattern.graph.nodes: # type: ignore[union-attr] + if ( + node.op == "call_function" + and node.target is torch.ops.aten.copy_.default + and len(node.users) == 0 + ): + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + return aten_pattern # type: ignore[return-value] + + +def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: + """Remove .tensor overload for quantize/dequantize ops so that we can + use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e + """ + _MAP = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, + torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, + torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, + } + for n in match_pattern.graph.nodes: + if n.op != "call_function": + continue + if n.target in _MAP: + n.target = _MAP[n.target] + + +def _is_literal(arg): + if isinstance(arg, (int, float)): + return True + if isinstance(arg, (tuple, list)): + return all(map(_is_literal, arg)) + return False + + +def _replace_literals_with_new_placeholders( + gm: torch.fx.GraphModule, + merge_dup: bool = False, + exclude_literals: list[Any] | None = None, +): + """Replace the literals in the graph with placeholder nodes that's created on the fly while we + traverse the graph, so that the literal arguments in the graph can be matched and replaced + + To use this, the pattern and replacement graph should have the exact same number of literal args + and they should be used in the exact same order in the pattern and replacement graph. + + If the literal arguments are not used in the same order in pattern and replacement graph, please + use `_replace_literals_with_existing_placeholders` instead + + Args: + `gm`: input GraphModule that we'll transform + `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in + the graph, whether they should correspond to the same placeholder or not + `exclude_literals`: a list of literals that will not be replaced with placeholders + + Example: + + # 1. Original Graph + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + example_inputs = (torch.randn(1, 3, 3, 3),) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) + replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) + + # 3. After replacing literals with new placeholder nodes + + def pattern(self, x, new_ph): + return x + new_ph + + def pattern(self, x, new_ph): + return x - new_ph + + """ + last_ph = None + cnt = 0 + literal_to_ph: dict[float | bool | int | torch.dtype, Node] = {} + if exclude_literals is None: + exclude_literals = [] + + in_spec = gm._in_spec + assert in_spec.type is tuple + args_spec = in_spec.child(0) + assert args_spec.type is tuple + args_spec_children = args_spec.children() + for node in gm.graph.nodes: + if node.op == "placeholder": + last_ph = node + cnt += 1 + continue + with gm.graph.inserting_after(last_ph): + new_args = [] + for arg in node.args: + if _is_literal(arg) and arg not in exclude_literals: + if merge_dup and arg in literal_to_ph: + new_args.append(literal_to_ph[arg]) + else: + ph_node = gm.graph.placeholder("arg" + str(cnt)) + new_args.append(ph_node) + args_spec_children.append(pytree.treespec_leaf()) + cnt += 1 + if merge_dup: + literal_to_ph[arg] = ph_node + else: + new_args.append(arg) + new_args = tuple(new_args) + + node.args = new_args + + # Update `num_nodes`, `num_leaves`, `num_children`. + args_spec = pytree.treespec_tuple(args_spec_children) + gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]]) + return gm + + +def _replace_literals_with_existing_placeholders( + gm: torch.fx.GraphModule, + exclude_literals: list[Any] | None = None, + literal_to_ph_idx: dict[float | int | bool | torch.dtype, int] | None = None, +): + """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments + in the graph can be matched and replaced + + To use this, all literal args in the graph should be unique and each of them should correspond + to exactly one placeholder node + + # 1. Original Graph + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + 1.0, + 0, + -128, + 127, + ) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, -128, 127) + return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32) + + # Note that literal args appear in different order in pattern and replacement graph, so + # we can't use _replace_literals_with_new_placeholders + + literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4} + pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx) + replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx) + + # 3. After replacing literals with existing placeholder nodes + + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + """ + if exclude_literals is None: + exclude_literals = [] + + if literal_to_ph_idx is None: + literal_to_ph_idx = {} + + phs = [node for node in gm.graph.nodes if node.op == "placeholder"] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + new_args = [] + for arg in node.args: + if ( + _is_literal(arg) + and arg not in exclude_literals + and arg in literal_to_ph_idx + ): + ph_idx = literal_to_ph_idx[arg] + ph_node = phs[ph_idx] + new_args.append(ph_node) + else: + new_args.append(arg) + new_args = tuple(new_args) + node.args = new_args + return gm + + +# TODO: Handle this in export itself and don't wrap the model in another GraphModule +# in prepare and convert +def _disallow_eval_train(model: GraphModule): + """ + Disallow calling `model.train()` or `model.eval()` on the given GraphModule. + This is useful for exported models, where these methods don't actually behave as expected. + """ + error_message = """ + Calling train() or eval() is not supported for exported models. + Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead. + + If you cannot replace the calls to `model.train()` and `model.eval()`, you may override + the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`, + which does the above automatically for you. Note that this has limited effect on switching + behavior between train and eval modes, and should be used only for special ops such as dropout + and batchnorm. + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cd5e8696d39781004960f47e6f44d3b1987ff4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py @@ -0,0 +1,22 @@ +from .quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + + +__all__ = [ + "EdgeOrNode", + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..15404cc560117713bf8c952f594c051b1c13e3a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .quantizer import QuantizationAnnotation, Quantizer + + +if TYPE_CHECKING: + import torch + from torch.fx import Node + +__all__ = [ + "ComposableQuantizer", +] + + +class ComposableQuantizer(Quantizer): + """ + ComposableQuantizer allows users to combine more than one quantizer into a single quantizer. + This allows users to quantize a model with multiple quantizers. E.g., embedding quantization + maybe supported by one quantizer while linear layers and other ops might be supported by another + quantizer. + + ComposableQuantizer is initialized with a list of `Quantizer` instances. + The order of the composition matters since that is the order in which the quantizers will be + applies. + Example: + ``` + embedding_quantizer = EmbeddingQuantizer() + linear_quantizer = MyLinearQuantizer() + xnnpack_quantizer = ( + XNNPackQuantizer() + ) # to handle ops not quantized by previous two quantizers + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, linear_quantizer, xnnpack_quantizer] + ) + prepared_m = prepare_pt2e(model, composed_quantizer) + ``` + """ + + def __init__(self, quantizers: list[Quantizer]): + super().__init__() + self.quantizers = quantizers + self._graph_annotations: dict[Node, QuantizationAnnotation] = {} + + def _record_and_validate_annotations( + self, gm: torch.fx.GraphModule, quantizer: Quantizer + ) -> None: + for n in gm.graph.nodes: + if "quantization_annotation" in n.meta: + # check if the annotation has been changed by + # comparing QuantizationAnnotation object id + if n in self._graph_annotations and ( + id(self._graph_annotations[n]) + != id(n.meta["quantization_annotation"]) + ): + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" + ) + else: + self._graph_annotations[n] = n.meta["quantization_annotation"] + else: + if n in self._graph_annotations: + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}" + ) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + for quantizer in self.quantizers: + quantizer.annotate(model) + self._record_and_validate_annotations(model, quantizer) + return model + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for quantizer in self.quantizers: + model = quantizer.transform_for_annotation(model) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8ef1030bfdcdeb88b58179f4f2ea83c895aad2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy + +import torch +import torch.nn.functional as F +from torch.ao.quantization.observer import PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + OperatorConfig, + OperatorPatternType, + QuantizationConfig, +) + + +__all__ = [ + "get_embedding_operators_config", + "EmbeddingQuantizer", +] + + +def get_embedding_operators_config() -> OperatorConfig: + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12), + ) + quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None) + ops: list[OperatorPatternType] = [[torch.nn.Embedding]] + ops.append([F.embedding]) + supported_config_and_operators = OperatorConfig( + config=quantization_config, operators=ops + ) + return copy.deepcopy(supported_config_and_operators) + + +class EmbeddingQuantizer(Quantizer): + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.get_supported_operators() + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: QuantizationConfig + ) -> list[OperatorPatternType]: + for config, ops in cls.get_supported_operators(): + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + self._annotate_embedding_ops(model.graph) + return model + + def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: + embedding_config: OperatorConfig = get_embedding_operators_config() + for node in graph.nodes: + # Keep node parsing based annotations instead of module partitioners + # just as an example of alternate ways of annotating + if ( + node.op == "call_function" + and node.target is torch.ops.aten.embedding.default + ): + if embedding_config.config.weight is None: + raise ValueError( + "Embedding config must have a valid weight quantization spec." + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + node.args[0]: embedding_config.config.weight, + } + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return [get_embedding_operators_config()] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e10526b4cc4ca58d099523d32ebd57a393a1dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py @@ -0,0 +1,182 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Annotated + +import torch +from torch import Tensor +from torch.ao.quantization import ObserverOrFakeQuantize +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.fx import Node + + +__all__ = [ + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "EdgeOrNode", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] + + +class QuantizationSpecBase(ABC): # noqa: B024 + """Base class for different types of quantization specs that allows users to + specify how to quantize a Tensor (input/output of a Node) in the model + """ + + +@dataclass(eq=True, frozen=True) +class QuantizationSpec(QuantizationSpecBase): + """Quantization spec for common operators that allows user to specify how to + quantize a Tensor, this includes dtype, quant_min, quant_max etc. + """ + + dtype: torch.dtype + # observer or fake_quantize constructor such as + # MinMaxObserver, PerChannelHistogramObserver etc. + # or we can attach some custom args to them + # e.g. MinMaxObserver.with_args(eps=eps) + observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor + quant_min: int | None = None + quant_max: int | None = None + qscheme: torch.qscheme | None = None + ch_axis: int | None = None + is_dynamic: bool = False + + def __post_init__(self): + # TODO: add init for quant_min/quant_max + # quant_min must be less than quant_max + if ( + self.quant_min is not None + and self.quant_max is not None + and self.quant_min > self.quant_max + ): + raise ValueError( + f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." + ) + + # ch_axis must be less than the number of channels + # but no way to check here. Just check that it is not < 0. + if self.ch_axis is not None and self.ch_axis < 0: + raise ValueError("Ch_axis is < 0.") + + +@dataclass(eq=True, frozen=True) +class FixedQParamsQuantizationSpec(QuantizationSpecBase): + dtype: torch.dtype + scale: float + zero_point: int + quant_min: int | None = None + quant_max: int | None = None + qscheme: torch.qscheme | None = None + is_dynamic: bool = False + + +""" +The way we refer to other points of quantization in the graph will be either +an input edge or an output value +input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node] +output value is an fx Node +""" +EdgeOrNode = Annotated[tuple[Node, Node] | Node, None] +EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer" + + +@dataclass(eq=True, frozen=True) +class SharedQuantizationSpec(QuantizationSpecBase): + """ + Quantization spec for the Tensors whose quantization parameters are shared with other Tensors + """ + + # the edge or node to share observer or fake quant instances with + edge_or_node: EdgeOrNode + + +@dataclass(eq=True, frozen=True) +class DerivedQuantizationSpec(QuantizationSpecBase): + """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors""" + + derived_from: list[EdgeOrNode] + derive_qparams_fn: Callable[[list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor]] + dtype: torch.dtype + quant_min: int | None = None + quant_max: int | None = None + qscheme: torch.qscheme | None = None + ch_axis: int | None = None + is_dynamic: bool = False + + +@dataclass +class QuantizationAnnotation: + """How are input argument or output should be quantized, + expressed as QuantizationSpec, this corresponds to how a Tensor in the + operator Graph is observed (PTQ) or fake quantized (QAT) + """ + + # a map from torch.fx.Node to a type of QuantizationSpecBase + input_qspec_map: dict[Node, QuantizationSpecBase | None] = field( + default_factory=dict + ) + + # How the output of this node is quantized, expressed as QuantizationSpec + # TODO: change the value to QuantizationSpec in a separate PR + output_qspec: QuantizationSpecBase | None = None + + # For a Node: node1 and edge: (node1, node2), since they are observing the same + # Tensor, we may want to implicitly share observers, this flag allows people to + # turn off this behavior for the output of the node + allow_implicit_sharing: bool = True + + # whether the node is annotated or not + _annotated: bool = False + + +class Quantizer(ABC): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Allows for user defined transforms to run before annotating the graph. + This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. + For example quantizer can + a) decompose a compound operator like scaled dot product attention, + into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa + or b) transform scalars to tensor to allow quantizing scalares. + + Note: this is an optional method + """ + return model + + # annotate nodes in the graph with observer or fake quant constructors + # to convey the desired way of quantization + @abstractmethod + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + pass + + # validate the annotated graph is supported by the backend + @abstractmethod + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + def prepare_obs_or_fq_callback( + self, + model: torch.fx.GraphModule, + edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize], + ) -> None: + """A callback that will be called after the observers or fake quants are created + for each sharing group, but before they are inserted into the graph. The + callback can be used to make final quantization adjustments, such as enforcing + specific scale and zero point on model input or output. + + Args: + * `model`: the graph module being prepared. + * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and + node to the corresponding observer or fake quant object. Note that multiple + edges and/or nodes can map to the same observer / fake quant instance if + they were annotated with SharedQuantizationSpec. This dictionary can be + modified by the callback. + """ + return diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06463ae0f2f3adb815d34b0f539fb6cde423e1ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py @@ -0,0 +1,90 @@ +from collections.abc import Callable + +from torch.ao.quantization.pt2e.utils import _is_sym_size_node +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpecBase, +) +from torch.fx import Node + + +__all__: list[str] = [] + + +def _annotate_input_qspec_map( + node: Node, input_node: Node, qspec: QuantizationSpecBase | None +) -> None: + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + quantization_annotation.input_qspec_map[input_node] = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _annotate_output_qspec(node: Node, qspec: QuantizationSpecBase | None) -> None: + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]) -> bool: + """ + This utility is used to handle cases when dynami_shape=True tracing leads + to symint nodes in the pattern of linear module. In those cases, we need to + distinguish between the nodes that are in input for just extracting value of + some dimensions (and symint nodes) vs. the one that is activation. + For example: + graph(x, y, weight): + size_0 = torch.ops.aten.sym_size([x], [0]) + size_1 = torch.ops.aten.sym_size([y], [1]) + view_size = size_0 * size_1 + size_3 = torch.ops.aten.sym_size([x], [2]) + vie_out = torch.ops.aten.view(x, [view_size, size_3]) + return mm(view_out, weight) + In the example above y node is not actual input. It exist only to extract size_1 + """ + if _is_sym_size_node(node): + return True + + return all( + ((user not in partition_nodes) or _is_sym_size_node(user)) + for user in node.users + ) + + +def _get_module_name_filter(module_name: str) -> Callable[[Node], bool]: + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n: str) -> str: + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cde0e2d12a6d00abfef6c2564b679286d99262 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -0,0 +1,1605 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import operator +import warnings +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, TypeAlias + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, + QuantizationConfig, +) +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions, + SourcePartition, +) + + +FilterFn: TypeAlias = Callable[[list[Node]], bool] + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "X86InductorQuantizer", + "get_default_x86_inductor_quantization_config", + "get_x86_inductor_linear_dynamic_fp16_config", +] + + +@dataclass +class _X86InductorQuantizationAnnotation(QuantizationAnnotation): + # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. + # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. + _is_output_of_quantized_pattern: bool = False + + +# Operators that: +# 1. Operators are optimized to run with int8 when int8 input provided. +# 2. Operators do not support int8 input and produce fp32 output. +int8_in_int8_out_ops: set = { + torch.ops.aten.max_pool2d.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.flatten.using_ints, +} + +# Operators that support the int8 data type for quantization config propagation. +# A superset of int8_in_int8_out_ops incorporating additional operators. +propagation_quantizable_ops = int8_in_int8_out_ops + +# Operators support the int8 data type +# and recipe is configured by default in X86InductorQuantizer. +default_quantizable_ops = propagation_quantizable_ops | { + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, +} + +# A superset of default_quantizable_ops includes operators support the int8 data type +# but not enabled by default recipe of X86InductorQuantizer. +quantizable_ops = default_quantizable_ops | { + torch.ops.aten.matmul.default, +} + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +def _skip_annotate(nodes: list[Node], filter_fn: FilterFn | None = None) -> bool: + """Determine whether to skip annotation for a list of nodes.""" + + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + return True + + # 2) Proceed annotate if a) a filter function is provided + # and b) the given nodes list passes the filter function check. + if filter_fn and filter_fn(nodes): + return False + + return True + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: list[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. + """ + + def operator_type_filter(nodes: list[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: list[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + +def _map_module_function_to_aten_operator_type(): + module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {} + map_list = ( + ([torch.nn.Conv2d, F.conv1d], torch.ops.aten.conv1d.default), + ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), + ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), + ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), + ( + [ + torch.cat, + ], + torch.ops.aten.cat.default, + ), + ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default), + ( + [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], + torch.ops.aten.adaptive_avg_pool2d.default, + ), + ( + [ + torch.flatten, + ], + torch.ops.aten.flatten.using_ints, + ), + ( + [ + torch.matmul, + ], + torch.ops.aten.matmul.default, + ), + ) + for map_item in map_list: + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] + return module_function_to_aten_operator + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if QUANT_ANNOTATION_KEY not in node.meta: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation() + node.meta[QUANT_ANNOTATION_KEY]._annotated = True + + +def _is_node_annotated(_node): + """ + return True if the node is annotated, otherwise return False + """ + return ( + QUANT_ANNOTATION_KEY in _node.meta + and _node.meta[QUANT_ANNOTATION_KEY]._annotated + ) + + +def _is_any_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False. + """ + return any(_is_node_annotated(node) for node in nodes) + + +def _is_all_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if all of the node is annotated, otherwise return False. + """ + return all(_is_node_annotated(node) for node in nodes) + + +def _is_quantized_op_pt2e(node: torch.fx.Node): + """ + Used for pt2e flow to check if the node is a quantized node: + Case1: the node has been annotated as output node of a fusion pattern. + Case2: the node has been annotated as single quantized node. + """ + if not _is_any_annotated([node]): + # The node has not been annotated, directly return False + return False + quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) + if not isinstance(quantization_annotation, _X86InductorQuantizationAnnotation): + raise AssertionError( + "quantization_annotation must be an _X86InductorQuantizationAnnotation" + ) + return quantization_annotation._is_output_of_quantized_pattern + + +@functools.lru_cache +def get_default_x86_inductor_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, + reduce_range: bool = False, +): + """ + reduce_range is False by default. Set it to True on earlier CPUs without VNNI to avoid accuracy issue. + """ + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=127 if reduce_range else 255, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver + ) + + if is_qat: + # Only support per channel quant for now + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +@functools.lru_cache +def get_x86_inductor_linear_dynamic_fp16_config(): + """ + For linear_dynamic_fp16. The name may be confusing. + The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output. + """ + weight_quantization_spec = QuantizationSpec( + dtype=torch.float16, + observer_or_fake_quant_ctr=PlaceholderObserver, + ) + quantization_config = QuantizationConfig( + None, # input_quantization_spec + None, # output_quantization_spec + weight_quantization_spec, + None, # bias_quantization_spec + ) + return quantization_config + + +def _annotate_nodes_not_quantize(nodes: Node | list[Node]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + +def _config_checker(method: Callable) -> Callable: + @functools.wraps(method) + def wrapper( + quantizer: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if quantizer._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name}.", + stacklevel=2, + ) + return quantizer + return method(quantizer, name, quantization_config) + + return wrapper + + +@dataclass +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. + + All possible current quantization modes are listed below: + ---------------------------------------------------------------------------------------------------------- + | dynamic_state + qat_state |--------------------------------------------------------------------------------------------- + | None | True | False + ---------------------------------------------------------------------------------------------------------- + None | quantizer does not receive a non-None `quantization_config` | \ | \ + False | quantizer will not do QAT | dynamic | static + True | quantizer will do QAT | QAT + dynamic | QAT + static + """ + + qat_state: bool | None + dynamic_state: bool | None + + +class X86InductorQuantizer(Quantizer): + module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() + + def __init__(self) -> None: + super().__init__() + self.global_config: QuantizationConfig | None = None + self.operator_type_qconfig: dict[ + torch._ops.OpOverloadPacket, QuantizationConfig | None + ] = {} + self.module_name_qconfig: dict[str, QuantizationConfig | None] = {} + + def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: + """Retrieves the current quantization mode based on all configurations.""" + qat_state = None + dynamic_state = None + + # As we use `_need_skip_config` to skip all invalid configurations, + # we can safely assume that the all existing non-None configurations + # have the same quantization mode. + # pyrefly: ignore [bad-assignment] + for qconfig in ( + list(self.module_name_qconfig.values()) + + list(self.operator_type_qconfig.values()) + + [self.global_config] + ): + if qconfig is not None: + # Query the `is_qat` state + if qat_state is None: + qat_state = qconfig.is_qat + else: + if qat_state != qconfig.is_qat: + raise AssertionError( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." + ) + # Query the `is_dynamic` state + input_activation_spec = qconfig.input_activation + if input_activation_spec is not None: + if dynamic_state is None: + dynamic_state = input_activation_spec.is_dynamic + else: + if dynamic_state != input_activation_spec.is_dynamic: + raise AssertionError( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) + return _CurrentQuantizationMode( + qat_state=qat_state, dynamic_state=dynamic_state + ) + + def _need_skip_config(self, quantization_config: QuantizationConfig | None) -> bool: + """Check if the provided quantization config is valid for X86InductorQuantizer. + + Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + To avoid such a mix, we compare the incoming configuration with current configuration status. + Refer the `_CurrentQuantizationMode` definition for all possible modes. + """ + if quantization_config is None: + return False + + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat + ): + warnings.warn( + "Mixed QAT and Non-QAT quantization config is not supported.", + stacklevel=2, + ) + need_skip = True + if current_mode.dynamic_state is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.dynamic_state != input_activation_spec.is_dynamic + ): + warnings.warn( + "Mixed dynamic and static quantization config is not supported.", + stacklevel=2, + ) + need_skip = True + return need_skip + + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.", stacklevel=2) + return self + self.global_config = quantization_config + return self + + def get_global_quantization_config(self): + if not isinstance(self.global_config, QuantizationConfig): + warnings.warn( + "The global_config for X86InductorQuantizer is currently invalid. \ + Please ensure that you use set_global to establish the global quantization configuration.", + stacklevel=2, + ) + return self.global_config + + @_config_checker + def set_function_type_qconfig( + self, + function_type: Callable, + quantization_config: QuantizationConfig | None, + ) -> "X86InductorQuantizer": + if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[ + function_type + ], + quantization_config, + ) + else: + warnings.warn( + f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer.", + stacklevel=2, + ) + return self + + @_config_checker + def set_module_type_qconfig( + self, + module_type: torch.nn.Module, + quantization_config: QuantizationConfig | None, + ) -> "X86InductorQuantizer": + if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[module_type], + quantization_config, + ) + else: + warnings.warn( + f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer.", + stacklevel=2, + ) + return self + + @_config_checker + def set_module_name_qconfig( + self, module_name: str, quantization_config: QuantizationConfig | None + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. + """ + self.module_name_qconfig[module_name] = quantization_config + return self + + def _set_aten_operator_qconfig( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig | None, + ) -> "X86InductorQuantizer": + if operator_type in quantizable_ops: + self.operator_type_qconfig[operator_type] = quantization_config + else: + warnings.warn( + f"operator: Unable to quantize {operator} by X86InductorQuantizer.", + stacklevel=2, + ) + return self + + def _annotate_conv_node_helper( + self, + conv_node: torch.fx.Node, + annotate_output: bool, + quantization_config: QuantizationConfig | None, + ) -> None: + """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return + input_qspec_map = {} + input_node = conv_node.args[0] + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + weight_node = conv_node.args[1] + if not isinstance(weight_node, Node): + raise AssertionError("weight_node must be a FX Node") + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + bias_node = None if len(conv_node.args) == 2 else conv_node.args[2] + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + if annotate_output: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + + def _annotate_linear_node_helper( + self, + linear_node: torch.fx.Node, + annotate_output: bool, + quantization_config: QuantizationConfig | None, + ) -> None: + """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return + input_qspec_map = {} + if linear_node.target is not torch.ops.aten.linear.default: + raise AssertionError( + "linear_node.target must be torch.ops.aten.linear.default" + ) + has_bias = len(linear_node.args) == 3 + input_index = 0 + weight_index = 1 + bias_index = 2 + + input_node = linear_node.args[input_index] + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + + weight_node = linear_node.args[weight_index] + if not isinstance(weight_node, Node): + raise AssertionError("weight_node must be a FX Node") + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + + bias_node = linear_node.args[bias_index] if has_bias else None + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + + if annotate_output: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + + def _get_output_nodes_of_partitions( + self, + partition_list: list[SourcePartition], + ) -> list[torch.fx.Node]: + """Helper function to get the output node list from partition list""" + output_node_list = [] + for partition in partition_list: + if len(partition.output_nodes) > 1: + raise ValueError("Input partition has more than one output node") + output_node = partition.output_nodes[0] + if not isinstance(output_node, Node): + raise AssertionError("output_node must be a FX Node") + output_node_list.append(output_node) + if len(output_node_list) != len(partition_list): + raise ValueError( + "length of output_node_list should equal to length of partition_list" + ) + return output_node_list + + def _get_input_idx_for_binary_node( + self, + conv_gemm_node: torch.fx.Node, + binary_node: torch.fx.Node, + ): + """Helper function to check conv_gemm and extra input node index + for binary node fused with conv_gemm. + """ + conv_gemm_node_idx = None + extra_input_node_idx = None + if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[0] == conv_gemm_node + ): + conv_gemm_node_idx = 0 + extra_input_node_idx = 1 + elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[1] == conv_gemm_node + ): + conv_gemm_node_idx = 1 + extra_input_node_idx = 0 + extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + if not isinstance(extra_input_node, Node): + raise AssertionError("extra_input_node must be a FX Node") + return conv_gemm_node_idx, extra_input_node_idx + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Annotate the given model with quantization configurations. + + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. + # Refer to + # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 # noqa: B950 + + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + + return model + + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. + + High-level description of quantization recipe for X86 Inductor Backend: + Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model + from start to the end. If a pattern supports computation with int8 data type and inputs connected to + quantized patterns, annotate its inputs as quantized pattern. + """ + + # Step1: Recipe of fusion patterns like conv/linear. + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) + + # Step2: Recipe to propagate annotation for patterns beside conv/linear. + # Go through all the nodes from start to end. + # Recipe refer to + # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 # noqa: B950 + + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + # Annotate QAT Specific patterns + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) + + def _annotate_qat_conv2d_bn_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + ( + conv_partition, + bn_partition, + binary_partition, + unary_partition, + ) = fused_partition + + ( + conv_node, + bn_output_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition, unary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # pyrefly: ignore [bad-argument-type] + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition, binary_partition = fused_partition + ( + conv_node, + bn_output_node, + binary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # pyrefly: ignore [bad-argument-type] + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(binary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, bn_partition, unary_partition = fused_partition + ( + conv_node, + bn_output_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, unary_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(unary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition = fused_partition + conv_node, bn_output_node = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(bn_output_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) + + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + # Conv2d + add + unary op + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition, unary_partition = fused_partition + conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition, unary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # pyrefly: ignore [bad-argument-type] + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + # Conv2d + add + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition = fused_partition + conv_node, binary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + if not isinstance(conv_node, Node): + raise AssertionError("conv_node must be a FX Node") + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # pyrefly: ignore [bad-argument-type] + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.SiLU], + [torch.nn.Conv1d, torch.nn.ReLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, unary_partition = fused_partition + conv_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, unary_partition] + ) + if conv_node.op != "call_function" or conv_node.target not in ( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ): + continue + if _skip_annotate([unary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d operator") + # skip annotation if it is already annotated + if _skip_annotate([conv_node], filter_fn): + continue + self._annotate_conv_node_helper(conv_node, True, quantization_config) + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: QuantizationConfig | None, + ) -> None: + if node.target is not torch.ops.aten.max_pool2d.default: + return + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + maxpool_node = node + if _is_any_annotated( + [ + maxpool_node, + ] + ): + return + + input_node = maxpool_node.args[0] + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_cat( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + cat_node = node + input_nodes = cat_node.args[0] + if not isinstance(input_nodes, Sequence): + raise AssertionError("input_nodes must be a Sequence of FX Nodes") + first_input_node = input_nodes[0] + input_qspec_map = {} + if not isinstance(first_input_node, Node): + raise AssertionError("first_input_node must be a FX Node") + if not isinstance(cat_node, Node): + raise AssertionError("cat_node must be a FX Node") + input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + # There has the case of cat same nodes: torch.cat([input0, input0], 1) + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_propagation_quantizable_pattern_entry( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + for node in gm.graph.nodes: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: + # Propagate annotation to quantizable patterns. + if ( + (node.target in propagation_quantizable_ops) + and (not _is_any_annotated([node])) + and (node.op == "call_function") + ): + + def is_all_inputs_connected_to_quantized_op(input_nodes): + # Ensure all the inputs connect to fusion pattern or quantized node + for input_node in input_nodes: + if not _is_quantized_op_pt2e(input_node): + return False + return True + + if _skip_annotate([node], filter_fn): + return + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + if node.target is torch.ops.aten.max_pool2d.default: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}.", + stacklevel=2, + ) + return + + self._annotate_maxpool2d(node, quantization_config) + return + elif node.target is torch.ops.aten.cat.default: + input_nodes_to_check = node.all_input_nodes + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_cat(node, quantization_config) + elif ( + node.target is torch.ops.aten.flatten.using_ints + and len(node.users) > 0 + and not any(user.target in quantizable_ops for user in node.users) + ): + # Recipe of flatten: check if any users of flatten node are quantizable ops or not + return + else: + input_node = node.all_input_nodes[0] + if not is_all_inputs_connected_to_quantized_op( + [ + input_node, + ] + ): + return + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + return + + def _annotate_output_share_observer_as_input( + self, input_node: Node, source_node: Node + ): + source_node_quantization_annotation = source_node.meta.get(QUANT_ANNOTATION_KEY) + if ( + source_node_quantization_annotation + and source_node_quantization_annotation._is_output_of_quantized_pattern + ): + edge_or_node = (input_node, source_node) + source_node_quantization_annotation.output_qspec = SharedQuantizationSpec( + edge_or_node + ) + return + + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, + model: torch.fx.GraphModule, + ): + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + r""" + Check and insert observer at output of node in int8_in_int8_out_ops if needed. + Recipe refers to + https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 + """ # noqa: B950 + edge_or_node: tuple[Node, Node] + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target is torch.ops.aten.max_pool2d.default: + maxpool_node = node + if not _is_all_annotated( + [ + maxpool_node, + ] + ): + return + + # Get the quantization_annotation from getitem_node + maxpool_node_quantization_annotation = maxpool_node.meta.get( + QUANT_ANNOTATION_KEY + ) + if ( + maxpool_node_quantization_annotation + and maxpool_node_quantization_annotation._is_output_of_quantized_pattern + ): + # Annotate the output_qspec of getitem_node + input_act = maxpool_node.args[0] + if not isinstance(input_act, Node): + raise AssertionError("input_act must be a FX Node") + if not isinstance(maxpool_node, Node): + raise AssertionError("maxpool_node must be a FX Node") + edge_or_node = (input_act, maxpool_node) + maxpool_node_quantization_annotation.output_qspec = ( + SharedQuantizationSpec(edge_or_node) + ) + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return + + def _annotate_linear( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if ( + linear_node.op != "call_function" + or linear_node.target != torch.ops.aten.linear.default + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if _skip_annotate([linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + + def _annotate_linear_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + torch.nn.GELU, + ] + fused_partitions: list[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if ( + linear_node.op != "call_function" + or linear_node.target != torch.ops.aten.linear.default + ): + continue + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) + continue + + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_linear_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ) -> None: + # linear + binary_op + (optional) unary op + binary_op_list = [operator.add] + unary_op_list = [torch.nn.ReLU, None] + combinations = itertools.product(binary_op_list, unary_op_list) + for binary_op, unary_op in combinations: + has_unary = unary_op is not None + seq_partition = [torch.nn.Linear, binary_op] + if has_unary: + # pyrefly: ignore [bad-argument-type] + seq_partition.append(unary_op) + fused_partitions = find_sequential_partitions(gm, seq_partition) + for fused_partition in fused_partitions: + unary_partition, unary_node = None, None + if has_unary: + ( + linear_partition, + binary_partition, + unary_partition, + ) = fused_partition + ( + linear_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition, unary_partition] + ) + else: + linear_partition, binary_partition = fused_partition + linear_node, binary_node = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition] + ) + if len(linear_node.users) != 1: + # Linear Node should only has 1 user node + continue + ( + linear_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(linear_node, binary_node) + if (linear_node_idx is None) or (extra_input_node_idx is None): + continue + if linear_node != binary_node.args[linear_node_idx]: + raise ValueError( + f"{linear_node} doesn't match input of binary node" + ) + if not isinstance(linear_node, Node): + raise AssertionError("linear_node must be a FX Node") + if ( + linear_node.op != "call_function" + or linear_node.target != torch.ops.aten.linear.default + ): + # No linear node found to be fused with add + continue + node_list = ( + [binary_node, linear_node] + if unary_node is None + else [unary_node, binary_node, linear_node] + ) + if _skip_annotate(node_list, filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) + continue + + self._annotate_linear_node_helper( + linear_node, False, quantization_config + ) + # We don't insert q-dq before the binary input node due to accuracy issues + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), + ) + ) + if unary_node is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a2234fdff3f137170d2810ef82fe8b7c706c0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -0,0 +1,451 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +import functools +import typing_extensions +from typing import Any, TYPE_CHECKING + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _convert_scalars_to_attrs, + OP_TO_ANNOTATOR, + OperatorConfig, + OperatorPatternType, + propagate_annotation, + QuantizationConfig, +) +from torch.fx._compatibility import compatibility + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: list[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]: + supported_operators: dict[str, list[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]: + supported_config_and_operators: list[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + supported_config_and_operators.extend( + OperatorConfig(quantization_config, pattern_list) + for pattern_list in ops.values() + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: + # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> list[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." + t.__qualname__ + types.append(t) + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: list[Callable], module_name_list: list[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + +@compatibility(is_backward_compatible=False) +@typing_extensions.deprecated( + "XNNPACKQuantizer is deprecated! Please use xnnpack quantizer in " + "ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead." +) +class XNNPACKQuantizer(Quantizer): + """ + !!! DEPRECATED !!! + XNNPACKQuantizer is a marked as deprecated. It will be removed in the future. + It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer. + Please use the new quantizer instead. + """ + + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + "conv_transpose_bn_relu", + "conv_transpose_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "conv_transpose_relu", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self) -> None: + super().__init__() + self.global_config: QuantizationConfig | None = None + self.operator_type_config: dict[ + torch._ops.OpOverloadPacket, QuantizationConfig | None + ] = {} + self.module_type_config: dict[Callable, QuantizationConfig | None] = {} + self.module_name_config: dict[str, QuantizationConfig | None] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.supported_config_and_operators + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: QuantizationConfig | None + ) -> list[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: QuantizationConfig | None + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + if quantization_config is None: + raise AssertionError("quantization_config == None is not supported yet") + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + # hacked for handling dynamic linear quant. will fix later. + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22282d3d071a899e31cd4607027aa3abec249c7f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -0,0 +1,1152 @@ +# mypy: allow-untyped-defs +import itertools +import typing +from collections.abc import Callable +from dataclasses import dataclass +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: QuantizationSpec | None + output_activation: QuantizationSpec | None + weight: QuantizationSpec | None + bias: QuantizationSpec | None + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +# Use Annotated because list[Callable].__module__ is read-only. +OperatorPatternType = typing.Annotated[list[Callable], None] +OperatorPatternType.__module__ = ( + "torch.ao.quantization.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + QuantizationConfig | None, + Callable[[Node], bool] | None, + ], + list[list[Node]] | None, +] +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} + + +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: + def decorator(annotator: AnnotatorType) -> None: + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. + config: QuantizationConfig + operators: list[OperatorPatternType] + + +def _is_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: QuantizationConfig | None): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + if quantization_spec.qscheme not in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ]: + raise AssertionError( + f"Unsupported activation qscheme: {quantization_spec.qscheme}" + ) + return quantization_spec + + +def get_output_act_qspec(quantization_config: QuantizationConfig | None): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + if quantization_spec.qscheme not in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ]: + raise AssertionError( + f"Unsupported activation qscheme: {quantization_spec.qscheme}" + ) + return quantization_spec + + +def get_weight_qspec(quantization_config: QuantizationConfig | None): + if quantization_config is None: + return None + if quantization_config is None: + raise AssertionError("quantization_config must not be None") + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + None, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: QuantizationConfig | None): + if quantization_config is None: + return None + if quantization_config is None: + raise AssertionError("quantization_config must not be None") + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + if quantization_spec.dtype != torch.float: + raise AssertionError( + "Only float dtype for bias is supported for bias right now" + ) + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + _annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + if len(linear_node.users) > 1: + # if linear node has multiple users, then it can't be fused with relu + continue + + input_qspec_map = {} + input_act = linear_node.args[0] + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv") +def _annotate_conv( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def _do_annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, + is_conv_transpose: bool = False, +): + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = n + maybe_conv_node = n.args[0] + + is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node + if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node): + continue + conv_node = maybe_conv_node + + if len(conv_node.users) > 1: + # relu shouldn't be fuseable to conv if there are other users + # of convolution + continue + + input_qspec_map = {} + input_act = conv_node.args[0] + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [relu_node, conv_node, conv_node.args[1]] + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + # pyrefly: ignore [bad-argument-type] + if _is_annotated(partition): + continue + + # pyrefly: ignore [bad-argument-type] + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + # pyrefly: ignore [bad-argument-type] + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_relu") +def _annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=False + ) + + +@register_annotator("conv_transpose_relu") +def _annotate_conv_transpose_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=True + ) + + +@register_annotator("conv_bn") +def _annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + """ + Find conv + batchnorm partitions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + """ + Find conv + batchnorm + relu partitions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +@register_annotator("conv_transpose_bn") +def _annotate_conv_transpose_bn( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + """ + Find conv_transpose + batchnorm partitions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True + ) + + +@register_annotator("conv_transpose_bn_relu") +def _annotate_conv_transpose_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + """ + Find conv_transpose + batchnorm + relu partitions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True + ) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None, + has_relu: bool, + is_conv_transpose: bool = False, +) -> list[list[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". + """ + + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + if is_conv_transpose: + combinations = [ + (F.conv_transpose1d, _conv1d_bn_example_inputs), + (F.conv_transpose2d, _conv2d_bn_example_inputs), + ] + else: + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), # type: ignore[list-item] + (F.conv2d, _conv2d_bn_example_inputs), # type: ignore[list-item] + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( # type: ignore[assignment] + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc] + pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type] + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type] + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + if not isinstance(input_act_user, Node): + raise AssertionError("input activation user must be a FX Node") + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + if not isinstance(hidden_state, Node): + raise AssertionError("hidden state must be a FX Node") + if not isinstance(hidden_state_user, Node): + raise AssertionError("hidden state user must be a FX Node") + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + if len(output_nodes) != 2: + raise AssertionError("expecting GRU to have two outputs") + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for partition in partitions: + pool_node = partition.output_nodes[0] + if ( + pool_node.op != "call_function" + or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default + ): + raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") + + if _is_annotated([pool_node]): + continue + + annotated_partitions.append(partition.nodes) + input_act = pool_node.args[0] + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + + # only annotate input output sharing operator + # when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + input_act_qspec = get_input_act_qspec(quantization_config) + else: + input_act_qspec = SharedQuantizationSpec(input_act) + + # output sharing with input + output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) + pool_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the node + since histc op (in HistogramObserver) only works for values up to certain upper bound + """ + if node.op == "get_attr": + qualified_name = str(node.target) + module_path, _, name = qualified_name.rpartition(".") + submod = gm.get_submodule(module_path) + tensor = getattr(submod, name) + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_input_non_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): + return True + return node.meta["val"].dtype != torch.float32 + + +@register_annotator("add_relu") +def _annotate_add_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_add = node.args[0] + if ( + not isinstance(maybe_add, Node) + or maybe_add.op != "call_function" + or maybe_add.target + not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ] + ): + continue + + add_node = maybe_add + + if len(add_node.users) > 1: + # add can't be fused with ReLU if the result of add is being used + # else where in the graph + continue + + partition = [relu_node, add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("add") +def _annotate_add( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + add_node = node + partition = [add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act1) + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul_relu") +def _annotate_mul_relu( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_mul = node.args[0] + if ( + not isinstance(maybe_mul, Node) + or maybe_mul.op != "call_function" + or maybe_mul.target + not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ] + ): + continue + + mul_node = maybe_mul + if len(mul_node.users) > 1: + # mul can't be fused with ReLU if the result of mul is being used + # else where in the graph + continue + + partition = [relu_node, mul_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul") +def _annotate_mul( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ]: + continue + + mul_node = node + partition = [mul_node] + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act0) + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +# TODO: remove Optional in return type, fix annotated_partitions logic +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: Callable[[Node], bool] | None = None, +) -> list[list[Node]] | None: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + cat_node = cat_partition.output_nodes[0] + if _is_annotated([cat_node]): + continue + + if cat_node.target != torch.ops.aten.cat.default: + # TODO: change this to AnnotationException + raise Exception( # noqa: TRY002 + f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}" + " please check if you are calling the correct capture API" + ) + + annotated_partitions.append(cat_partition.nodes) + + input_act_qspec = get_input_act_qspec(quantization_config) + inputs = cat_node.args[0] + + input_qspec_map = {} + input_act0 = inputs[0] # type: ignore[index] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec + + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type] + for input_act in inputs[1:]: # type: ignore[index, union-attr] + if input_act not in input_qspec_map: + input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + + output_act_qspec = shared_with_input0_qspec + + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable) -> bool: + return op in [ + torch.ops.aten.relu.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + # TODO: remove? + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# TODO: make the list of ops customizable +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0fc48fd54fa17b6ed0db900677ab339d62a988 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +import functools +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + _is_any_annotated, + FilterFn, + int8_in_int8_out_ops, + X86InductorQuantizer, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig +from torch.fx import Node + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "XPUInductorQuantizer", + "get_default_xpu_inductor_quantization_config", +] + + +@functools.lru_cache +def get_default_xpu_inductor_quantization_config(): + extra_args: dict[str, Any] = {"eps": 2**-12} + act_observer_or_fake_quant_ctr = HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + PerChannelMinMaxObserver + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + False, + ) + return quantization_config + + +class XPUInductorQuantizer(X86InductorQuantizer): + """ + XPUInductorQuantizer is a class designed to facilitate + quantization capability at Intel GPU backend. The class + highly reuses the existing implementation of + X86InductorQuantizer as both are intended to take advantage + of the optimized kernels in oneDNN library. + """ + + """ + Following annotate_xx overrides the impls in base class, as + no XPU implementation for these operators currently. We would + gradually enable the XPU implementation and remove following + overrides. We keep the annotate methods but make the function + body empty, aiming to let `_generate_qdq_quantized_model` + generate qdq around op and graph execute on fp32 dtype for + unsupported operators. + """ + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: QuantizationConfig | None, + filter_fn: FilterFn | None = None, + ): + pass + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: QuantizationConfig | None, + ) -> None: + """ + Here we skip the annotate logic for maxpool at XPU backend + as the quantized::max_pool2d is only implemented for CPU. + """ + return + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target is torch.ops.aten.max_pool2d.default: + return + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75f8090e937f94a2bc8c87138963a0be70fc5cbd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/_reduction.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/_reduction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6698501685e52adaa364084558d97ea51a2a1ce4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/_reduction.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/common_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/common_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..120843d1673402256c545bb3e38800f8c1fe1752 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/common_types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/cpp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/cpp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ded39e3098c90f6e37da3b873a391b84bb5a0ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/cpp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/grad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..141ef1ad07360fdec2cdc40809349ae31c59f4ea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/grad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/init.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/init.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fc317c035377890376c1759bd113fc91ae99761 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/init.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/parameter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/parameter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e11c109aac2ef6456aae19cd931f188652d210fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/__pycache__/parameter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..438a8bc55caf0b73780288496a943848d7a71191 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__init__.py @@ -0,0 +1,187 @@ +# mypy: allow-untyped-defs +"""This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention""" + +import contextlib +from collections.abc import Iterable +from typing import Union +from warnings import warn + +import torch.backends.cuda +from torch._C import _SDPBackend as SDPBackend +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + SDPAParams, +) + + +__all__: list[str] = [ + "SDPBackend", + "sdpa_kernel", + "WARN_FOR_UNFUSED_KERNELS", + "register_flash_attention_impl", + "activate_flash_attention_impl", + "list_flash_attention_impls", + "current_flash_attention_impl", +] + + +# Note: [SDPA warnings] +# TODO: Consider using this for sdpa regardless of subclasses +# This only effects users of bias subclasses +# If this is set to True, we will warn the user if they are not using the fused kernels +# As well, it will raise warnings for all the reasons why the fused kernels can't be run. +# To set this to True, run +# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True +WARN_FOR_UNFUSED_KERNELS = False + + +r"""An enum-like class that contains the different backends for scaled dot product attention. + This backend class is designed to be used with the sdpa_kernel context manager. + + The following Enums are available: + - ERROR: An error occurred when trying to determine the backend. + - MATH: The math backend for scaled dot product attention. + - FLASH_ATTENTION: The flash attention backend for scaled dot product attention. + - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention. + - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention. + - OVERRIDEABLE: The overridable backend for extension. + + See :func:`torch.nn.attention.sdpa_kernel` for more details. + + .. warning:: This class is in beta and subject to change. +""" +SDPBackend.__module__ = __name__ +SDPBackend.__name__ = "SDPBackend" + + +def _raise_kernel_warnings(params: SDPAParams) -> None: + """ + If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings + for all the reasons why the fused kernels can't be run. If using subclasses + """ + if WARN_FOR_UNFUSED_KERNELS: + if not can_use_efficient_attention(params): + warn("Efficient attention can't be used because:", stacklevel=2) + can_use_efficient_attention(params, True) + if not can_use_flash_attention(params): + warn("Flash attention can't be used because:", stacklevel=2) + can_use_flash_attention(params, True) + + +_backend_names = { + "cudnn": "CUDNN_ATTENTION", + "flash": "FLASH_ATTENTION", + "mem_efficient": "EFFICIENT_ATTENTION", + "math": "MATH", + "overrideable": "OVERRIDEABLE", +} + + +def _backend_from_string(name: str): + return getattr(SDPBackend, name) + + +def _cur_sdpa_kernel_backends(with_priority: bool = False): + backends = [] + for name, val in _backend_names.items(): + if getattr(torch._C, f"_get_{name}_sdp_enabled")(): + backends.append(getattr(SDPBackend, val)) + if with_priority: + curr_priority = torch._C._get_sdp_priority_order() + backends = sorted( + backends, key=lambda backend: curr_priority.index(int(backend)) + ) + return backends + + +def _sdpa_kernel(backends: Iterable, set_priority: bool = False) -> None: + for name, val in _backend_names.items(): + enabled = getattr(SDPBackend, val) in backends + getattr(torch._C, f"_set_sdp_use_{name}")(enabled) + if set_priority: + # backends should be a unique list + user_priority = [int(backend) for backend in backends] + previous_priority = torch._C._get_sdp_priority_order() + for backend in previous_priority: + if backend not in user_priority: + user_priority.append(int(backend)) + torch._C._set_sdp_priority_order(user_priority) + + +@contextlib.contextmanager +def sdpa_kernel(backends: list[SDPBackend] | SDPBackend, set_priority: bool = False): + r""" + Context manager to select which backend to use for scaled dot product attention. + + .. warning:: This function is beta and subject to change. + + Args: + backends (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention. + set_priority_order (bool=False): Whether the ordering of the backends is interpreted as their priority order. + + Example: + + .. code-block:: python + + from torch.nn.functional import scaled_dot_product_attention + from torch.nn.attention import SDPBackend, sdpa_kernel + + # Only enable flash attention backend + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + scaled_dot_product_attention(...) + + # Enable the Math or Efficient attention backends + with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + scaled_dot_product_attention(...) + + This context manager can be used to select which backend to use for scaled dot product attention. + Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends. + """ + assert isinstance(backends, (list, SDPBackend)), ( + "Backend must be an instance of SDPBackend or a list of SDPBackend instances" + ) + + if isinstance(backends, SDPBackend): + backends = [backends] + + backends = list(dict.fromkeys(backends)) + + previous_backends = _cur_sdpa_kernel_backends(with_priority=set_priority) + try: + _sdpa_kernel(backends, set_priority) + yield {} + finally: + _sdpa_kernel(previous_backends, set_priority) + + +# variadic version of sdpa_kernel for dynamo to use while reconstructing +@contextlib.contextmanager +def _sdpa_kernel_variadic(*backends: SDPBackend): + with sdpa_kernel(list(backends)): + yield + + +def _get_flash_version() -> str: + """This returns the closest matching tag for the flash attention backend""" + return "2.5.7" + + +from . import _registry + + +# Re-export registry types and functions for public API +_FlashAttentionImpl = _registry._FlashAttentionImpl +_RegisterFn = _registry._RegisterFn +register_flash_attention_impl = _registry.register_flash_attention_impl +activate_flash_attention_impl = _registry.activate_flash_attention_impl +list_flash_attention_impls = _registry.list_flash_attention_impls +current_flash_attention_impl = _registry.current_flash_attention_impl + +register_flash_attention_impl.__module__ = __name__ +activate_flash_attention_impl.__module__ = __name__ +list_flash_attention_impls.__module__ = __name__ +current_flash_attention_impl.__module__ = __name__ + +# Import built-in implementations to trigger self-registration +from . import _fa4 # noqa: F401 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acd0be78fcc58ef94232aa898e2f46a656f5ea60 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_fa4.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_fa4.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8792bb557ff12fb792125f397e768e37f42b2c44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_fa4.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdde7403a9df831c18d24fc8889f4363bf808788 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ff4c8e334d3ebcbbf35497e1c12ddcdd0c029ee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/bias.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/bias.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c276062f3e60c1b931bdbf705ef505361eb71524 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/bias.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05f9b33f049e74cd317158e2c0310167deb92b7c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/flex_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/varlen.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/varlen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67181a956473d7738d30493caa0551d26f550bef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/__pycache__/varlen.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_fa4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_fa4.py new file mode 100644 index 0000000000000000000000000000000000000000..1be960ee53218e4fe01ac1f16c7416ecc0ff3822 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_fa4.py @@ -0,0 +1,456 @@ +"""UBER PROTOTYPE!!!""" +# mypy: allow-untyped-defs + +from __future__ import annotations + +import importlib +from dataclasses import dataclass +from functools import cache +from typing import Any, TYPE_CHECKING +from typing_extensions import TypeVarTuple, Unpack + +from . import _registry + + +if TYPE_CHECKING: + from types import ModuleType + +import torch +from torch.library import Library + + +__all__ = [ + "register_flash_attention_fa4", +] + + +_FA4_MODULE_PATH: str | None = None + + +@dataclass +class _FA4Handle: + library: Library | None + + def remove(self) -> None: + self.library = None + + +@cache +def _get_device_major(device: torch.device) -> int: + major, _ = torch.cuda.get_device_capability(device) + return major + + +def register_flash_attention_fa4( + module_path: str = "flash_attn.cute.interface", +) -> _FA4Handle: + """ + Register FA4 flash attention kernels with the PyTorch dispatcher. + + Args: + module_path: Python module path to the FA4 implementation. + """ + global _FA4_MODULE_PATH + _ = _fa4_import_module(module_path) + _FA4_MODULE_PATH = module_path + return _FA4Handle(_fa4_register_kernels()) + + +@cache +def _fa4_import_module(module_path: str) -> ModuleType: + module = importlib.import_module(module_path) + if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"): + raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels") + return module + + +def _fa4_register_kernels() -> Library: + lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901 + lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA") + lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA") + lib.impl( + "_scaled_dot_product_flash_attention", + _fa4_scaled_dot_product_flash_attention_forward_impl, + "CUDA", + ) + lib.impl( + "_scaled_dot_product_flash_attention_backward", + _fa4_scaled_dot_product_flash_attention_backward_impl, + "CUDA", + ) + return lib + + +def _fa4_common_support_error( + query: torch.Tensor, + tensors: tuple[torch.Tensor, ...], + cum_seq_q: torch.Tensor | None, + require_fp32: tuple[tuple[str, torch.Tensor], ...] = (), +) -> str | None: + if not all(t.is_cuda for t in tensors): + return "inputs must be CUDA tensors" + if len({t.device for t in tensors}) != 1: + return "inputs must share device" + if query.dtype not in (torch.float16, torch.bfloat16): + return "query dtype must be float16 or bfloat16" + for name, tensor in require_fp32: + if tensor.dtype != torch.float32: + return f"{name} dtype must be float32" + if cum_seq_q is None and query.dim() != 4: + return "dense query must be 4D" + if cum_seq_q is not None and query.dim() != 3: + return "ragged query must be 3D" + if not torch.cuda.is_available(): + return "CUDA not available" + if _get_device_major(query.device) not in (9, 10): + return "FA4 requires compute capability 9.0 or 10.0" + return None + + +def _fa4_forward_support_error( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float, + return_debug_mask: bool, + alibi_slopes: torch.Tensor | None, + seqused_k: torch.Tensor | None, + cum_seq_q: torch.Tensor | None, +) -> str | None: + if dropout_p != 0.0: + return "dropout_p must be 0" + if return_debug_mask: + return "return_debug_mask must be False" + if alibi_slopes is not None: + return "alibi_slopes not supported" + if seqused_k is not None: + if seqused_k.dtype != torch.int32: + return "seqused_k must be int32" + if not seqused_k.is_cuda: + return "seqused_k must be CUDA" + error = _fa4_common_support_error( + query, + (query, key, value), + cum_seq_q, + ) + if error is not None: + if error == "inputs must share device": + return "query, key, value must be on same device" + return error + return None + + +def _fa4_backward_support_error( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + dropout_p: float, + cum_seq_q: torch.Tensor | None, + window_size_left: int | None, + window_size_right: int | None, +) -> str | None: + if dropout_p != 0.0: + return "dropout_p must be 0" + if window_size_left is not None or window_size_right is not None: + return "windowed attention not supported" + error = _fa4_common_support_error( + query, + (grad_out, query, key, value, out, logsumexp), + cum_seq_q, + require_fp32=(("logsumexp", logsumexp),), + ) + if error is not None: + return error + return None + + +Ts = TypeVarTuple("Ts") + + +def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]: + return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined] + + +def _fa4_run_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor | None, + cu_seq_k: torch.Tensor | None, + scale: float | None, + is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, + seqused_k: torch.Tensor | None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if _FA4_MODULE_PATH is None: + raise RuntimeError("FA4 not registered") + module = _fa4_import_module(_FA4_MODULE_PATH) + + kwargs: dict[str, Any] = { + "softmax_scale": scale, + "causal": is_causal, + "window_size_left": window_size_left, + "window_size_right": window_size_right, + "return_lse": True, + "cu_seqlens_q": cu_seq_q, + "cu_seqlens_k": cu_seq_k, + "seqused_k": seqused_k.contiguous() if seqused_k is not None else None, + } + if out is not None: + kwargs["out"] = out + out, lse = module._flash_attn_fwd(query, key, value, **kwargs) + return out, lse.contiguous() + + +def _fa4_run_backward( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cu_seq_q: torch.Tensor | None, + cu_seq_k: torch.Tensor | None, + scale: float | None, + is_causal: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if _FA4_MODULE_PATH is None: + raise RuntimeError("FA4 not registered") + module = _fa4_import_module(_FA4_MODULE_PATH) + dq, dk, dv = module._flash_attn_bwd( + query, + key, + value, + out, + grad_out, + logsumexp.contiguous(), + softmax_scale=scale, + causal=is_causal, + cu_seqlens_q=cu_seq_q, + cu_seqlens_k=cu_seq_k, + ) + return dq, dk, dv + + +def _fa4_flash_attention_forward_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cum_seq_q: torch.Tensor | None, + cum_seq_k: torch.Tensor | None, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + *, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, + seqused_k: torch.Tensor | None = None, + alibi_slopes: torch.Tensor | None = None, + out: torch.Tensor | None = None, +): + error = _fa4_forward_support_error( + query, + key, + value, + dropout_p, + return_debug_mask, + alibi_slopes, + seqused_k, + cum_seq_q, + ) + if error is not None: + raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}") + out, lse = _fa4_run_forward( + query, + key, + value, + cum_seq_q, + cum_seq_k, + scale, + is_causal, + window_size_left, + window_size_right, + seqused_k, + out, + ) + rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device) + philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device) + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + return out, lse, rng_state, philox_offset, debug_mask + + +def _fa4_flash_attention_backward_impl( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cum_seq_q: torch.Tensor | None, + cum_seq_k: torch.Tensor | None, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + rng_state: torch.Tensor, + unused: torch.Tensor, + *, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, +): + error = _fa4_backward_support_error( + grad_out, + query, + key, + value, + out, + logsumexp, + dropout_p, + cum_seq_q, + window_size_left, + window_size_right, + ) + if error is not None: + raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}") + dq, dk, dv = _fa4_run_backward( + grad_out, + query, + key, + value, + out, + logsumexp, + cum_seq_q, + cum_seq_k, + scale, + is_causal, + ) + return dq, dk, dv + + +def _fa4_scaled_dot_product_flash_attention_forward_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: float | None = None, +): + error = _fa4_forward_support_error( + query, + key, + value, + dropout_p, + return_debug_mask, + None, + None, + None, + ) + if error is not None: + raise RuntimeError(f"FA4 SDPA forward unsupported: {error}") + q, k, v = _transpose_dense(query, key, value) + + # Pre-allocate output with query's strides (BHSD layout), then create + # a BSHD view for the kernel. This ensures the returned output has + # the same memory layout as the input query. + out_bhsd = torch.empty_like(query) + out_bshd = out_bhsd.transpose(1, 2) + + max_q_flash = q.size(1) + max_k_flash = k.size(1) + _, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl( + q, + k, + v, + None, + None, + max_q_flash, + max_k_flash, + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + out=out_bshd, + ) + max_q = query.size(2) + max_k = key.size(2) + return ( + out_bhsd, + lse, + None, + None, + max_q, + max_k, + rng_state, + philox_offset, + debug_mask, + ) + + +def _fa4_scaled_dot_product_flash_attention_backward_impl( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cum_seq_q: torch.Tensor | None, + cum_seq_k: torch.Tensor | None, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + *, + scale: float | None = None, +): + error = _fa4_backward_support_error( + grad_out, + query, + key, + value, + out, + logsumexp, + dropout_p, + None, + None, + None, + ) + if error is not None: + raise RuntimeError(f"FA4 SDPA backward unsupported: {error}") + q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out) + max_q = query.size(2) + max_k = key.size(2) + dq, dk, dv = _fa4_flash_attention_backward_impl( + go, + q, + k, + v, + o, + logsumexp, + None, + None, + max_q, + max_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale=scale, + ) + dq, dk, dv = _transpose_dense(dq, dk, dv) + return dq, dk, dv + + +_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..883252d56f8b65cfa258d9d77ed463b374fd77ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_registry.py @@ -0,0 +1,109 @@ +# mypy: allow-untyped-defs +"""Registry for flash attention implementations. + +This module contains the registration system for flash attention implementations. +It has no torch dependencies to avoid circular imports during initialization. +""" + +from collections.abc import Callable +from typing import Literal, Protocol + + +class FlashAttentionHandle(Protocol): + def remove(self) -> None: ... + + +_RegisterFn = Callable[..., FlashAttentionHandle | None] +_FlashAttentionImpl = Literal["FA4"] + +_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {} + +_FLASH_ATTENTION_ACTIVE: str | None = None +_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {} + + +def register_flash_attention_impl( + impl: str | _FlashAttentionImpl, + *, + register_fn: _RegisterFn, +) -> None: + """ + Register the callable that activates a flash attention impl. + + .. note:: + This function is intended for SDPA backend providers to register their + implementations. End users should use :func:`activate_flash_attention_impl` + to activate a registered implementation. + + Args: + impl: Implementation identifier (e.g., ``"FA4"``). + register_fn: Callable that performs the actual dispatcher registration. + This function will be invoked by :func:`activate_flash_attention_impl` + and should register custom kernels with the PyTorch dispatcher. + It may optionally return a handle implementing + :class:`FlashAttentionHandle` to keep any necessary state alive. + + Example: + >>> def my_impl_register(module_path: str = "my_flash_impl"): + ... # Register custom kernels with torch dispatcher + ... pass # doctest: +SKIP + >>> register_flash_attention_impl( + ... "MyImpl", register_fn=my_impl_register + ... ) # doctest: +SKIP + """ + _FLASH_ATTENTION_IMPLS[impl] = register_fn + + +def activate_flash_attention_impl( + impl: str | _FlashAttentionImpl, +) -> None: + """ + Activate into the dispatcher a previously registered flash attention impl. + + .. note:: + Backend providers should NOT automatically activate their implementation + on import. Users should explicitly opt-in by calling this function or via + environment variables to ensure multiple provider libraries can coexist. + + Args: + impl: Implementation identifier to activate. See + :func:`~torch.nn.attention.list_flash_attention_impls` for available + implementations. + If the backend's :func:`register_flash_attention_impl` callable + returns a :class:`FlashAttentionHandle`, the registry keeps that + handle alive for the lifetime of the process (until explicit + uninstall support exists). + + Example: + >>> activate_flash_attention_impl("FA4") # doctest: +SKIP + """ + global _FLASH_ATTENTION_ACTIVE + register_fn = _FLASH_ATTENTION_IMPLS.get(impl) + if register_fn is None: + raise ValueError( + f"Unknown flash attention impl '{impl}'. " + f"Available implementations: {list_flash_attention_impls()}" + ) + # TODO: The only way to actually register a new impl is to unregister the current impl + # reinstall the default impl and then register the new impl + if _FLASH_ATTENTION_ACTIVE == impl: + return + + handle = register_fn() + if handle is not None: + _FLASH_ATTENTION_HANDLES[impl] = handle + _FLASH_ATTENTION_ACTIVE = impl + + +def list_flash_attention_impls() -> list[str]: + """Return the names of all available flash attention implementations.""" + return sorted(_FLASH_ATTENTION_IMPLS.keys()) + + +def current_flash_attention_impl() -> str | None: + """ + Return the currently activated flash attention impl name, if any. + + ``None`` indicates that no custom impl has been activated. + """ + return _FLASH_ATTENTION_ACTIVE diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd530bb675e8fce9164b7de0d75fd9dce90edec8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/_utils.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +"""Defines utilities for interacting with scaled_dot_product_attention""" + +import math + +import torch + + +__all__: list[str] = [] + + +def _input_requires_grad(*tensors: torch.Tensor) -> bool: + """Returns True if any of the tensors requires grad""" + return any(t.requires_grad for t in tensors) + + +def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor: + """Handles the unpad of the last dimension""" + if inpt_tensor.size(-1) != og_size: + return inpt_tensor[..., :og_size] + return inpt_tensor + + +def _calculate_scale(head_dim_size: int, scale: float | None) -> float: + """ + For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output + by the original head size and not the padded. + """ + if scale is not None: + return scale + return 1.0 / math.sqrt(head_dim_size) + + +def _validate_sdpa_input( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p=0.0, + is_causal=False, + scale=None, +) -> None: + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + f"Expected query, key, and value to have the same dtype, " + f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, " + f"and value.dtype: {value.dtype} instead." + ) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to have the same device type, " + f"but got query.device: {query.device}, key.device: {key.device}, " + f"and value.device: {value.device} instead." + ) + if query.dim() < 2 or key.dim() < 2 or value.dim() < 2: + raise ValueError( + f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: " + f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/bias.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/bias.py new file mode 100644 index 0000000000000000000000000000000000000000..746e04c01f3d571fc61a06a62332f229ada4e6c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/bias.py @@ -0,0 +1,371 @@ +# mypy: allow-untyped-defs +"""Defines bias subclasses that work with scaled_dot_product_attention""" + +from enum import auto, IntEnum +from warnings import warn + +import torch +import torch.nn.functional as F +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + is_flash_attention_available, + SDPAParams, +) +from torch.nn.attention import _raise_kernel_warnings +from torch.nn.attention._utils import ( + _calculate_scale, + _input_requires_grad, + _postprocess_flash_output, + _validate_sdpa_input, +) + + +__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"] + + +torch._dynamo.allow_in_graph(is_flash_attention_available) +torch._dynamo.allow_in_graph(can_use_flash_attention) +torch._dynamo.allow_in_graph(can_use_efficient_attention) +torch._dynamo.allow_in_graph(SDPAParams) + + +class CausalVariant(IntEnum): + r""" + Enum for causal variants used in attention mechanisms. + + Defines two types of causal biases: + + ``UPPER_LEFT``: Represents upper-left triangular bias for standard causal attention. + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + + ``LOWER_RIGHT``: Represents lower-right triangular bias, the include values are aligned to the lower + right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Note that these variants are equivalent to each other when the sequence lengths of the query and key/value + tensors are equal since the triangular matrix is square. + + .. warning:: This enum is a prototype and subject to change. + """ + + UPPER_LEFT = auto() + LOWER_RIGHT = auto() + + +class CausalBias(torch.Tensor): + """ + A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum. + + This class is used for defining causal (triangular) attention biases. For construing the bias, there exist + two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`. + + Example: + + .. code-block:: python + + from torch.nn.attention.bias import causal_lower_right + + bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8 + + # Create a lower-right causal bias + attn_bias = causal_lower_right(seqlen_q, seqlen_kv) + + q = torch.randn( + bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16 + ) + k = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) + + out = F.scaled_dot_product_attention(q, k, v, attn_bias) + + .. warning:: This class is a prototype and subject to change. + """ + + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None: + """ + Initializes the CausalBias instance with a specified variant and sequence lengths. + + Args: + variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT). + seq_len_q (int): The sequence length of the query tensor. + seq_len_kv (int): The sequence length of the key/value tensor. + + Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs. + """ + assert isinstance(variant, CausalVariant) + super().__init__() + self.variant = variant + self.seq_len_q = seq_len_q + self.seq_len_kv = seq_len_kv + if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT: + warn( + "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!", + stacklevel=2, + ) + + def _upper_left(self, device: torch.device) -> torch.Tensor: + """Upper left causal bias""" + return torch.tril( + torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool) + ) + + def _lower_right(self, device: torch.device) -> torch.Tensor: + """Lower right causal bias""" + diagonal_offset = self.seq_len_kv - self.seq_len_q + return torch.tril( + torch.ones( + self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool + ), + diagonal=diagonal_offset, + ) + + # pyrefly: ignore [bad-return] + def _materialize(self, device: torch.device | None = None) -> torch.Tensor: + """ + Materializes the causal bias into a tensor form. + + Depending on the variant, this method generates either an upper-left or lower-right + triangular matrix to represent the causal bias. + + Args: + device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU. + + Returns: + torch.Tensor: The materialized bias tensor. + """ + if device is None: + device = torch.device("cpu") + if self.variant == CausalVariant.UPPER_LEFT: + return self._upper_left(device) + elif self.variant == CausalVariant.LOWER_RIGHT: + return self._lower_right(device) + + @staticmethod + def _dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: "CausalBias", + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, + ) -> torch.Tensor: + r""" + Handles the logic for computing attention with the specified causal bias. + + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. + attn_mask (CausalBias): The type of causal attention to apply. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal + are set. + scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. + + Raises: + ValueError: If the causal bias variant is not a CausalVariant type. + + """ + if is_causal: + raise ValueError("CausalBias should not be used with causal=True") + + if ( + attn_mask.seq_len_q == attn_mask.seq_len_kv + or attn_mask.variant == CausalVariant.UPPER_LEFT + ): + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + enable_gqa=enable_gqa, + ) + elif attn_mask.variant == CausalVariant.LOWER_RIGHT: + _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale) + sdpa_params = SDPAParams( + query, key, value, None, dropout_p, is_causal, enable_gqa + ) + if can_use_flash_attention(sdpa_params): + alignment = 64 if query.device.type == "xpu" else 8 + og_head_size = query.size(-1) + og_scale = _calculate_scale(og_head_size, scale) + needs_padding = og_head_size % alignment != 0 + if needs_padding: + pad_len = alignment - (og_head_size % alignment) + query = torch.nn.functional.pad(query, (0, pad_len)) + key = torch.nn.functional.pad(key, (0, pad_len)) + value = torch.nn.functional.pad(value, (0, pad_len)) + out = torch.ops.aten._scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right + return_debug_mask=False, + scale=og_scale, + )[0] + return _postprocess_flash_output(out, og_head_size) + if can_use_efficient_attention(sdpa_params): + compute_log_sumexp = False + if _input_requires_grad(query, key, value): + compute_log_sumexp = True + return torch.ops.aten._efficient_attention_forward( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=int(attn_mask.variant), + compute_log_sumexp=compute_log_sumexp, + scale=scale, + seqlen_k=None, + )[0].transpose(1, 2) + else: + _raise_kernel_warnings(sdpa_params) + # We can't use efficient attention the only support for lower right is via materialization + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask._materialize(query.device), + dropout_p=dropout_p, + is_causal=False, + scale=scale, + enable_gqa=enable_gqa, + ) + else: + raise ValueError( + f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}" + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + return cls._dispatch(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) + + def __repr__(self) -> str: # type:ignore[override] + return self._materialize().__repr__() + + +def causal_upper_left(*size) -> CausalBias: + """ + Creates an upper-left triangular causal bias. + + This function generates a upper-left triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix. + This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The UPPER_LEFT triangular causal bias variant. + """ + assert len(size) == 2, "causal_upper_left only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv) + + +def causal_lower_right(*size) -> CausalBias: + """ + Creates a lower-right triangular causal bias. + + This function generates a lower-right triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The LOWER_RIGHT triangular causal bias variant. + """ + assert len(size) == 2, "causal_lower_right only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6694bbe3990bacac6025e8c8bd4ab86e80d2e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py @@ -0,0 +1,2 @@ +# Experimental features are not mature yet and are subject to change. +# We do not provide any BC/FC guarantees diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4909ae836f3f959e750d29778ea35653c994e436 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd81e558660798f746dddc0de3b860fa240acd70 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbbc2b78aa6ab54983965458d1901dc4e1a1bb1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs +""" +This module implements Paged Attention on top of flex_attention. +This module is experimental and subject to change. +""" + +import torch +from torch.nn.attention.flex_attention import ( + _identity, + _mask_mod_signature, + _score_mod_signature, + BlockMask, + noop_mask, +) + + +__all__ = ["PagedAttention"] + + +def _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.Tensor): + return (x + multiple - 1) // multiple + + +class PagedAttention: + """ + PagedAttention supports flex attention inference with a large batch size. + With PagedAttention, a batch of key/value tensors with varying kv length + is split into tensor blocks of fixed length and cached in a compact way. + Thus we can avoid redundant memory consumption due to varying kv length and + support a larger batch size. + """ + + def __init__( + self, + n_pages: int, + page_size: int, + max_batch_size: int, + device: str = "cuda", + ) -> None: + # number of pages + self.n_pages = n_pages + + # number of tokens per page + self.page_size = page_size + + # page table: [batch, logical_block_idx] -> physical_page_idx + self.page_table = -torch.ones( + (max_batch_size, self.n_pages), dtype=torch.int64, device=device + ) + + # capacity: batch_idx -> allocated sequence length + self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device) + + # index of empty pages that is available for allocation + self.empty_pages = list(range(n_pages - 1, -1, -1)) + + # mapping from physical page index to logical page index + self.physical_to_logical = -torch.ones( + (max_batch_size, n_pages), dtype=torch.int64, device=device + ) + + def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None: + """ + Requests the capacity of a given batch to be at least enough to + hold `seq_len` elements. + + Args: + batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`. + seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`. + """ + + if seq_len <= self.capacity[batch_idx]: + return + + num_pages_to_allocate = _cdiv( + seq_len - self.capacity[batch_idx], self.page_size + ) + + assert len(self.empty_pages) >= num_pages_to_allocate, ( + f"requested {num_pages_to_allocate.item()} pages " + f"but there are only {len(self.empty_pages)} empty pages" + ) + + start_page_idx = self.capacity[batch_idx] // self.page_size + end_page_idx = start_page_idx + num_pages_to_allocate + + # find empty physical pages + allocated_pages = torch.tensor( + self.empty_pages[-num_pages_to_allocate:], + device=num_pages_to_allocate.device, + ) + self.empty_pages = self.empty_pages[:-num_pages_to_allocate] + + # update page table + self.page_table[ + batch_idx, + start_page_idx:end_page_idx, + ] = allocated_pages + + # update metadata + self.physical_to_logical[batch_idx, allocated_pages] = torch.arange( + start_page_idx.item(), + end_page_idx.item(), + device=num_pages_to_allocate.device, + ) + self.capacity[batch_idx] += num_pages_to_allocate * self.page_size + + def erase(self, batch_idx: torch.Tensor) -> None: + """ + Removes a single batch from paged attention. + + Args: + batch_idx (Tensor): batch index to be removed; shape :math:`(1)`. + """ + + # find allocated pages + allocated_page_idx = self.page_table[batch_idx] != -1 + allocated_pages = self.page_table[batch_idx][allocated_page_idx] + + # clean metadata + self.capacity[batch_idx] = 0 + self.empty_pages += allocated_pages.tolist() + self.physical_to_logical[batch_idx][:, allocated_pages] = -1 + self.page_table[batch_idx] = -1 + + def assign( + self, + batch_idx: torch.Tensor, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ) -> None: + """ + Assigns new contents `val` to the storage `cache` at the location + `batch_idx` and `input_pos`. + + Args: + batch_idx (Tensor): batch index; shape :math:`(B)`. + input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`. + val (Tensor): value to be assigned; shape :math:`(B, H, S, D)` + cache (Tensor): the cache to store the values; shape:`(1, H, MAX_S, D)` + """ + if k_val.requires_grad: + raise RuntimeError("val must not require gradient") + + B, H, S, K_D = k_val.shape + V_D = v_val.shape[3] + if B != batch_idx.shape[0]: + raise RuntimeError( + f"Expect val and batch_idx have the same batch size " + f"but got B={B} and B={batch_idx.shape[0]}." + ) + if H != k_cache.shape[1]: + raise RuntimeError( + f"Expect val and cache has the same number of heads " + f"but got H={H} and H={k_cache.shape[1]}." + ) + if S != input_pos.shape[1]: + raise RuntimeError( + f"Expect val and input_pos has the same length " + f"but got S={S} and S={input_pos.shape[0]}." + ) + if K_D != k_cache.shape[3]: + raise RuntimeError( + f"Expect k_val and k_cache has the same hidden dim " + f"but got D={K_D} and D={k_cache.shape[3]}." + ) + if V_D != v_cache.shape[3]: + raise RuntimeError( + f"Expect v_val and v_cache has the same hidden dim " + f"but got D={V_D} and D={v_cache.shape[3]}." + ) + + # find address + logical_block_idx = input_pos // self.page_size # [B, S] + logical_block_offset = input_pos % self.page_size # [B, S] + physical_block_idx = torch.gather( + self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64) + ).to(torch.int32) # [B, S] + + addr = (physical_block_idx * self.page_size + logical_block_offset).view( + -1 + ) # [B*S] + + k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D) + v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D) + + k_cache[:, :, addr, :] = k_val + v_cache[:, :, addr, :] = v_val + + def convert_logical_block_mask( + self, + block_mask: BlockMask, + batch_idx: torch.Tensor | None = None, + kv_len: torch.Tensor | None = None, + ) -> BlockMask: + """ + Converts a logical block mask by mapping its logical kv indices to the corresponding + physical kv indices. + + Args: + block_mask (BlockMask): logical block mask; + kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`. + batch_idx (Tensor): batch index corresponding to the block_mask + batch dimension. This provides flexibility to convert a + block mask with smaller batch size than the page table; + shape :math:`(B)`. + kv_len (Optional[Tensor]): actual KV sequence length for upper bound check; + shape :math:`(B,)` to handle multiple batches. + """ + B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape + + if block_mask.BLOCK_SIZE[1] != self.page_size: + raise RuntimeError( + f"Expect block_mask has the same column block size as page_size" + f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}" + ) + + # Increase the num columns of converted block mask from logical block mask's + # num columns to n_pages, since a) the converted block mask + # may have larger indices values; and b) `_ordered_to_dense` realizes + # a dense tensor with these converted indices. There would be an IndexError + # if using the logical block mask's num columns. + + device = block_mask.kv_num_blocks.device + + if batch_idx is None: + batch_idx = torch.arange(B, device=device) + page_table = self.page_table[batch_idx] + + new_kv_num_blocks = block_mask.kv_num_blocks.clone() + + new_kv_indices = torch.zeros( + (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device + ) + new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = ( + torch.gather( + page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64) + ) + .view(block_mask.kv_indices.shape) + .to(torch.int32) + ) + + new_full_kv_indices, new_full_kv_num_blocks = None, None + if block_mask.full_kv_num_blocks is not None: + assert block_mask.full_kv_indices is not None + new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone() + new_full_kv_indices = torch.zeros( + (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device + ) + new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = ( + torch.gather( + page_table, + 1, + block_mask.full_kv_indices.view(B, -1).to(torch.int64), + ) + .view(block_mask.full_kv_indices.shape) + .to(torch.int32) + ) + + new_mask_mod = self.get_mask_mod(block_mask.mask_mod, kv_len) + + seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size) + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + block_mask.BLOCK_SIZE, + new_mask_mod, + seq_lengths=seq_lengths, + ) + + def get_mask_mod( + self, + mask_mod: _mask_mod_signature | None, + kv_len: torch.Tensor | None = None, + ) -> _mask_mod_signature: + """ + Converts a mask_mod based on mapping from the physical block index to the logical + block index. + + Args: + mask_mod (_mask_mod_signature): mask_mod based on the logical block index. + kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check. + """ + if mask_mod is None: + mask_mod = noop_mask + + def new_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ): + physical_kv_block = physical_kv_idx // self.page_size + physical_kv_offset = physical_kv_idx % self.page_size + logical_block_idx = self.physical_to_logical[b, physical_kv_block] + logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset + live_block = logical_block_idx >= 0 + within_upper_bound = ( + logical_kv_idx < kv_len[b] if kv_len is not None else True + ) + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False) + + return new_mask_mod + + def get_score_mod( + self, + score_mod: _score_mod_signature | None, + kv_len: torch.Tensor | None = None, + ) -> _score_mod_signature: + """ + Converts a score_mod based on mapping from the physical block index to the logical + block index. + + Args: + score_mod (_score_mod_signature): score_mod based on the logical block index. + `kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check. + + """ + if score_mod is None: + score_mod = _identity + + def new_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ): + physical_kv_block = physical_kv_idx // self.page_size + physical_kv_offset = physical_kv_idx % self.page_size + logical_block_idx = self.physical_to_logical[b, physical_kv_block] + logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset + live_block = logical_block_idx >= 0 + within_upper_bound = ( + logical_kv_idx < kv_len[b] if kv_len is not None else True + ) + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + return torch.where( + is_valid, + score_mod(score, b, h, q_idx, logical_kv_idx), + float("-inf"), + ) + + return new_score_mod diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ad922227ccff80de42fcefe74c52ea861124add4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py @@ -0,0 +1,1676 @@ +# mypy: allow-untyped-defs +# flake8: noqa: B950 +"""This module implements the user facing API for flex_attention in PyTorch.""" + +import functools +import inspect +import itertools +import math +import operator +import typing +import warnings +from collections.abc import Callable +from enum import Enum +from typing import Any, Literal, NamedTuple, TypeAlias + +import torch +from torch import Tensor + + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + +from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop +from torch._higher_order_ops.utils import _set_compilation_env +from torch._prims_common import DeviceLikeType +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, +) +from torch.nn.attention._utils import _validate_sdpa_input +from torch.utils._pytree import GetAttrKey, tree_map_only + + +# Private debug flag to disable internal compilation wrapping for debugging purposes. +# WARNING: This is intended ONLY for debugging score_mod and mask_mod functions. +# When enabled, this bypasses the required internal compilation that ensures correctness +# and performance. Only use this temporarily when you need to set breakpoints +# in your score_mod/mask_mod functions during development. +# +# This flag only affects the internal compilation when flex_attention is called directly. +# If you have already wrapped flex_attention in torch.compile(), this flag has no effect +# and the user's compilation will still occur. +# +# Usage: +# import torch.nn.attention.flex_attention as fa +# fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True +# # Now you can set breakpoints in your score_mod/mask_mod +# output = fa.flex_attention(q, k, v, score_mod=my_score_mod) +# +_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False + +_WARNINGS_SHOWN: set[str] = set() + + +def _warn_once( + warning_id: str, message: str, category: type[Warning] = UserWarning +) -> None: + """Helper to ensure each warning is shown only once per process.""" + if warning_id not in _WARNINGS_SHOWN: + warnings.warn(message, category, stacklevel=2) + _WARNINGS_SHOWN.add(warning_id) + + +__all__ = [ + "BlockMask", + "flex_attention", + "AuxOutput", + "AuxRequest", + "FlexKernelOptions", + "create_block_mask", + "create_mask", + "or_masks", + "and_masks", + "noop_mask", +] + +_score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] +_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +_Backend: TypeAlias = Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"] + + +# pyrefly: ignore [invalid-inheritance] +class FlexKernelOptions(TypedDict, total=False): + """Options for controlling the behavior of FlexAttention kernels. + + These options are passed to the underlying Triton kernels to control performance + and numerical behavior. Most users will not need to specify these options as the + default autotuning provides good performance. + + The options can be prefixed with ``fwd_`` or ``bwd_`` to apply only to forward or + backward pass respectively. For example: ``fwd_BLOCK_M`` and ``bwd_BLOCK_M1``. + + Note: + We currently do not provide any backward compatibility guarantees for these options. + That being said most of these have remained pretty stable since their introduction. But + We do not consider this part of the public API just yet. We think that some documentation + Is better than secret hidden flags, but we may change these options in the future. + + Example Usage: + .. code-block:: python + + # Using dictionary (backward compatible) + kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True} + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Using TypedDict (recommended for type safety) + from torch.nn.attention.flex_attention import FlexKernelOptions + + kernel_opts: FlexKernelOptions = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "PRESCALE_QK": True, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Forward/backward specific options + kernel_opts: FlexKernelOptions = { + "fwd_BLOCK_M": 64, + "bwd_BLOCK_M1": 32, + "PRESCALE_QK": False, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + """ + + # Performance tuning options + # pyrefly: ignore [invalid-annotation] + num_warps: NotRequired[int] + """Number of warps to use in the CUDA kernel. Higher values may improve performance + but increase register pressure. Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + num_stages: NotRequired[int] + """Number of pipeline stages in the CUDA kernel. Higher values may improve performance + but increase shared memory usage. Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + BLOCK_M: NotRequired[int] + """Thread block size for the sequence length dimension of Q in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + BLOCK_N: NotRequired[int] + """Thread block size for the sequence length dimension of K/V in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + # Backward-specific block sizes (when prefixed with 'bwd_') + # pyrefly: ignore [invalid-annotation] + BLOCK_M1: NotRequired[int] + """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. + Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + BLOCK_N1: NotRequired[int] + """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. + Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + BLOCK_M2: NotRequired[int] + """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. + Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + BLOCK_N2: NotRequired[int] + """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. + Default is determined by autotuning.""" + + # pyrefly: ignore [invalid-annotation] + PRESCALE_QK: NotRequired[bool] + """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but + may have more numerical error. Default: False.""" + + # pyrefly: ignore [invalid-annotation] + ROWS_GUARANTEED_SAFE: NotRequired[bool] + """If True, guarantees that at least one value in each row is not masked out. + Allows skipping safety checks for better performance. Only set this if you are certain + your mask guarantees this property. For example, causal attention is guaranteed safe + because each query has at least 1 key-value to attend to. Default: False.""" + + # pyrefly: ignore [invalid-annotation] + BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] + """If True, guarantees that all blocks in the mask are contiguous. + Allows optimizing block traversal. For example, causal masks would satisfy this, + but prefix_lm + sliding window would not. Default: False.""" + + # pyrefly: ignore [invalid-annotation] + WRITE_DQ: NotRequired[bool] + """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. + Setting this to False will force this to happen in the DK loop which depending on your + specific score_mod and mask_mod might be faster. Default: True.""" + + # pyrefly: ignore [invalid-annotation] + FORCE_USE_FLEX_ATTENTION: NotRequired[bool] + """If True, forces the use of the flex attention kernel instead of potentially using + the more optimized flex-decoding kernel for short sequences. This can be a helpful + option for debugging. Default: False.""" + + # pyrefly: ignore [invalid-annotation] + USE_TMA: NotRequired[bool] + """Whether to use Tensor Memory Accelerator (TMA) on supported hardware. + This is experimental and may not work on all hardware, currently specific + to NVIDIA GPUs Hopper+. Default: False.""" + + # ROCm-specific options + # pyrefly: ignore [invalid-annotation] + kpack: NotRequired[int] + """ROCm-specific kernel packing parameter.""" + + # pyrefly: ignore [invalid-annotation] + matrix_instr_nonkdim: NotRequired[int] + """ROCm-specific matrix instruction non-K dimension.""" + + # pyrefly: ignore [invalid-annotation] + waves_per_eu: NotRequired[int] + """ROCm-specific waves per execution unit.""" + + # pyrefly: ignore [invalid-annotation] + BACKEND: NotRequired[_Backend] + """Selects a specific kernel backend. + + Options: + - "AUTO": Use current heuristics (typically Triton-based kernels with + automatic selection between flex_attention and flex_decoding) + - "TRITON": Standard Triton flex_attention kernel + - "TRITON_DECODE": Triton flex_decoding kernel, only available for short sequence lengths with specific configurations + - "FLASH": Experimental: Flash Attention kernel (cute-dsl), user needs to have flash installed + + This option cannot be combined with legacy knobs such as ``FORCE_USE_FLEX_ATTENTION``. + Raises an error if the requested backend cannot be used. Default: "AUTO" + """ + + +class AuxRequest(NamedTuple): + """Request which auxiliary outputs to compute from flex_attention. + + Each field is a boolean indicating whether that auxiliary output should be computed. + """ + + lse: bool = False + max_scores: bool = False + + +class AuxOutput(NamedTuple): + """Auxiliary outputs from flex_attention operation. + + Fields will be None if not requested, or contain the tensor if requested. + """ + + lse: Tensor | None = None + max_scores: Tensor | None = None + + +class _ModificationType(Enum): + """Enum for the type of modification function. + - SCORE_MOD: score_mod function which accepts a score as the first argument + - mask_mod: mask function which does not accept a score and is only used for generating + block mask + """ + + SCORE_MOD = 1 + MASK_MOD = 2 + UNKNOWN = 3 + + +def _get_mod_type(fn: Callable) -> _ModificationType: + """Get the type of modification function. + This function inspects the number of positional arguments of the function to determine + the type of modification function. If the function has 5 positional arguments, it is + considered as a score_mod function. If the function has 4 positional arguments, it is + considered as a mask function. + """ + if hasattr(fn, "__code__"): + code = fn.__code__ + num_positional_total = code.co_argcount + defaults = () + if hasattr(fn, "__defaults__"): + defaults = fn.__defaults__ or () + num_defaults = len(defaults) + num_positional_args = num_positional_total - num_defaults + else: + num_positional_args = sum( + 1 + for param in inspect.signature(fn).parameters.values() + if param.default is inspect.Parameter.empty + ) + assert num_positional_args == 5 or num_positional_args == 4 + if num_positional_args == 5: + return _ModificationType.SCORE_MOD + elif num_positional_args == 4: + return _ModificationType.MASK_MOD + else: + return _ModificationType.UNKNOWN + + +# Need to define it here so that Dynamo doesn't skip it +def _vmap_for_bhqkv( + fn: Callable, + prefix: tuple[int | None, ...], + suffix: tuple[int | None, ...] = (), + out_dims: int | list[int | None] = 0, + group_dim: bool = False, +): + """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs. + Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions. + + Args: + fn (callable): The function to vmap. + prefix (tuple): The prefix of the vmap. For score mod functions, + this should be set to (0,). For mask_mods = () + suffix (tuple): We need to add (0,) if gradOut is being mapped over, + and (None,) * len(other_buffers). + out_dims (tuple): For forward cases, keep this as the default 0 since + we are only returning 1 output. For backwards, the joint + graph returns grads for B, H, Q_idx, KV_idx and other_buffers, + so we set this to (0, None, None, None, None) + (None,) * len(other_buffers). + + Returns: + callable: The vmapped function. + """ + # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions + dimensions: list[tuple[None | int, None | int, None | int, None | int]] = [] + dimensions = [ + (None, None, None, 0), + (None, None, 0, None), + (None, 0, None, None), + ] + + if group_dim: + dimensions += [ + (None, 0, None, None), + ] + + dimensions += [ + (0, None, None, None), + ] + + for dims in dimensions: + fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims) # type: ignore[arg-type] + return fn + + +def _identity( + score: Tensor, + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, +) -> Tensor: + return score + + +def noop_mask( + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, +) -> Tensor: + """Returns a noop mask_mod""" + return batch.new_ones(size=(), dtype=torch.bool, device=batch.device) + + +def _sliced_mask_mod_error( + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, +) -> Tensor: + """ + Raises helpful error when using mask_mod from a sliced BlockMask. + + After slicing a BlockMask, the mask_mod is reset and cannot be used directly. + Users must reassign mask_mod from the original (unsliced) BlockMask. + """ + raise RuntimeError( + "Cannot use mask_mod from a sliced BlockMask. " + "When you slice a BlockMask using [], the mask_mod attribute is reset. " + "You must set it from the original BlockMask's mask_mod." + "\n\nIncorrect usage:" + "\n base_mask = create_block_mask(my_mask_fn, ...)" + "\n sliced_mask = base_mask[:, :, block_idx]" + "\n sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, offset) # WRONG!" + "\n\nCorrect usage:" + "\n base_mask = create_block_mask(my_mask_fn, ...)" + "\n sliced_mask = base_mask[:, :, block_idx]" + "\n sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, offset) # Use base_mask!" + ) + + +_DEFAULT_SPARSE_BLOCK_SIZE = 128 +_LARGE_SPARSE_BLOCK_SIZE = 1 << 30 + + +def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor): + num_rows = col_indices.shape[-2] + num_cols = col_indices.shape[-1] + batch_dims = num_blocks_in_row.shape[:-1] + device = num_blocks_in_row.device + + def create_dense_one(kv_num_blocks, kv_indices): + dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32) + + row_indices = torch.arange(num_rows, dtype=torch.int, device=device).unsqueeze( + -1 + ) + col_range = torch.arange(num_cols, dtype=torch.int, device=device) + index_mask = col_range < kv_num_blocks.unsqueeze(-1) + + # We write to one spot "out of bounds" + valid_indices = torch.where(index_mask, kv_indices, num_cols) + + # set the values in 'a' to 1 where the indices are valid + dense_mask[row_indices, valid_indices] = dense_mask.new_ones(()) + return dense_mask[:, :num_cols].contiguous() + + create_dense_batched = create_dense_one + for _ in range(len(batch_dims)): + create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0)) + + out = create_dense_batched(num_blocks_in_row, col_indices) + return out + + +def _dense_to_ordered(dense_mask) -> tuple[Tensor, Tensor]: + dense_mask = dense_mask.to(dtype=torch.int32) + num_blocks_in_row = dense_mask.sum(dim=-1) + col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True) + return ( + num_blocks_in_row.to(torch.int32, memory_format=torch.contiguous_format), + col_indices.to(torch.int32, memory_format=torch.contiguous_format), + ) + + +def _transpose_ordered(num_blocks_in_row: Tensor, col_indices: Tensor): + dense = _ordered_to_dense(num_blocks_in_row, col_indices) + return _dense_to_ordered(dense.transpose(-2, -1)) + + +def _adjust_num_blocks_and_indices( + num_blocks: Tensor, + indices: Tensor, + new_num_rows: int, + new_num_cols: int, +): + indices = indices[:, :, :new_num_rows, :new_num_cols] + num_blocks = num_blocks[:, :, :new_num_rows] + num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols) + num_blocks = torch.sum(indices < num_blocks[:, :, :, None], dim=-1).to(torch.int32) + return num_blocks, indices + + +class BlockMask: + r""" + BlockMask is our format for representing a block-sparse attention mask. + It is somewhat of a cross in-between BCSR and a non-sparse format. + + **Basics** + + A block-sparse mask means that instead of representing the sparsity of + individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is + considered sparse only if every element within that block is sparse. + This aligns well with hardware, which generally expects to perform + contiguous loads and computation. + + This format is primarily optimized for 1. simplicity, and 2. kernel + efficiency. Notably, it is *not* optimized for size, as this mask is always + reduced by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If the size is a + concern, the tensors can be reduced in size by increasing the block size. + + The essentials of our format are: + + num_blocks_in_row: Tensor[ROWS]: + Describes the number of blocks present in each row. + + col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: + `col_indices[i]` is the sequence of block positions for row i. The values of + this row after `col_indices[i][num_blocks_in_row[i]]` are undefined. + + For example, to reconstruct the original tensor from this format: + + .. code-block:: python + + dense_mask = torch.zeros(ROWS, COLS) + for row in range(ROWS): + for block_idx in range(num_blocks_in_row[row]): + dense_mask[row, col_indices[row, block_idx]] = 1 + + Notably, this format makes it easier to implement a reduction along the + *rows* of the mask. + + **Details** + + The basics of our format require only kv_num_blocks and kv_indices. But, we + have up to 8 tensors on this object. This represents 4 pairs: + + 1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as + we reduce along the KV dimension. + + 2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and + purely an optimization. As it turns out, applying masking to every block + is quite expensive! If we specifically know which blocks are "full" and + don't require masking at all, then we can skip applying mask_mod to these + blocks. This requires the user to split out a separate mask_mod from the + score_mod. For causal masks, this is about a 15% speedup. + + 3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass, + as computing dKV requires iterating along the mask along the Q dimension. These are autogenerated from 1. + + 4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as above, but for + the backwards pass. These are autogenerated from 2. + """ + + seq_lengths: tuple[int, int] + kv_num_blocks: Tensor + kv_indices: Tensor + full_kv_num_blocks: Tensor | None + full_kv_indices: Tensor | None + q_num_blocks: Tensor | None + q_indices: Tensor | None + full_q_num_blocks: Tensor | None + full_q_indices: Tensor | None + BLOCK_SIZE: tuple[int, int] + mask_mod: _mask_mod_signature + + # Attribute lists for pytree flatten/unflatten + _TENSOR_ATTRS = [ + "kv_num_blocks", + "kv_indices", + "full_kv_num_blocks", + "full_kv_indices", + "q_num_blocks", + "q_indices", + "full_q_num_blocks", + "full_q_indices", + ] + + _CONTEXT_ATTRS = [ + "seq_lengths", + "BLOCK_SIZE", + "mask_mod", + ] + + def __init__( + self, + seq_lengths: tuple[int, int], + kv_num_blocks: Tensor, + kv_indices: Tensor, + full_kv_num_blocks: Tensor | None, + full_kv_indices: Tensor | None, + q_num_blocks: Tensor | None, + q_indices: Tensor | None, + full_q_num_blocks: Tensor | None, + full_q_indices: Tensor | None, + BLOCK_SIZE: tuple[int, int], + mask_mod: _mask_mod_signature, + ) -> None: + if kv_indices.dim() < 2: + raise RuntimeError("BlockMask must have at least 2 dimensions") + assert kv_num_blocks is not None, "kv_num_blocks must be provided" + assert kv_indices is not None, "kv_indices must be provided" + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) + assert (full_q_num_blocks is None) == (full_q_indices is None), ( + "full_q_num_blocks and full_q_indices must be both provided or omitted" + ) + + self.seq_lengths = seq_lengths + self.kv_num_blocks = kv_num_blocks + self.kv_indices = kv_indices + self.full_kv_num_blocks = full_kv_num_blocks + self.full_kv_indices = full_kv_indices + self.q_num_blocks = q_num_blocks + self.q_indices = q_indices + self.full_q_num_blocks = full_q_num_blocks + self.full_q_indices = full_q_indices + self.BLOCK_SIZE = BLOCK_SIZE + self.mask_mod = mask_mod + + @classmethod + def from_kv_blocks( + cls, + kv_num_blocks: Tensor, + kv_indices: Tensor, + full_kv_num_blocks: Tensor | None = None, + full_kv_indices: Tensor | None = None, + BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE, + mask_mod: _mask_mod_signature | None = None, + seq_lengths: tuple[int, int] | None = None, + compute_q_blocks: bool = True, + ): + """ + Creates a BlockMask instance from key-value block information. + + Args: + kv_num_blocks (Tensor): Number of kv_blocks in each Q_BLOCK_SIZE row tile. + kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile. + full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile. + full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile. + BLOCK_SIZE (Union[int, tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles. + mask_mod (Optional[Callable]): Function to modify the mask. + + Returns: + BlockMask: Instance with full Q information generated via _transposed_ordered + + Raises: + RuntimeError: If kv_indices has < 2 dimensions. + AssertionError: If only one of full_kv_* args is provided. + """ + if kv_indices.dim() < 2: + raise RuntimeError("BlockMask must have at least 2 dimensions") + + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) + + # Generate q_num_blocks and q_indices + if compute_q_blocks: + q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) + if full_kv_num_blocks is not None: + assert full_kv_indices is not None + full_q_num_blocks, full_q_indices = _transpose_ordered( + full_kv_num_blocks, full_kv_indices + ) + else: + full_q_num_blocks, full_q_indices = None, None + else: + q_num_blocks, q_indices = None, None + full_q_num_blocks, full_q_indices = None, None + + if isinstance(BLOCK_SIZE, int): + BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE) + + mask_mod = mask_mod if mask_mod is not None else noop_mask + if seq_lengths is None: + q_length = kv_indices.shape[-2] * BLOCK_SIZE[0] + kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1] + seq_lengths = (q_length, kv_length) + + return cls( + seq_lengths=seq_lengths, + kv_num_blocks=kv_num_blocks, + kv_indices=kv_indices, + full_kv_num_blocks=full_kv_num_blocks, + full_kv_indices=full_kv_indices, + q_num_blocks=q_num_blocks, + q_indices=q_indices, + full_q_num_blocks=full_q_num_blocks, + full_q_indices=full_q_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=mask_mod, + ) + + def as_tuple(self, flatten: bool = True): + """ + Returns a tuple of the attributes of the BlockMask. + + Args: + flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE) + """ + if flatten: + block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) # type: ignore[assignment] + seq_lengths = (self.seq_lengths[0], self.seq_lengths[1]) # type: ignore[assignment] + else: + block_size = (self.BLOCK_SIZE,) # type: ignore[assignment] + seq_lengths = (self.seq_lengths,) # type: ignore[assignment] + + # pyrefly: ignore [not-iterable] + return ( + *seq_lengths, + self.kv_num_blocks, + self.kv_indices, + self.full_kv_num_blocks, + self.full_kv_indices, + self.q_num_blocks, + self.q_indices, + self.full_q_num_blocks, + self.full_q_indices, + *block_size, + self.mask_mod, + ) + + @property + def shape(self): + *batch_dims, _, _ = self.kv_indices.shape + return tuple(batch_dims) + self.seq_lengths + + def __str__(self) -> str: + s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" + mask_str = self.to_string().strip() + s += mask_str + s += "\n)" + return s + + def __getitem__(self, index) -> "BlockMask": + """ + Returns a new BlockMask instance by getting the mask for the given index position. + + Args: + index: Index to apply to all attributes. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + block_mask = create_block_mask( + causal_mask, 4, 2, 512, 512, device="cuda" + ) + assert block_mask.kv_num_blocks.shape == (4, 2, 4) + assert block_mask.kv_indices.shape == (4, 2, 4, 4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4, 4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) + assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[ + 0:2, 1:2, torch.tensor([1], dtype=torch.int32) + ] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) + assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) + """ + index = (index,) if not isinstance(index, tuple) else index + padded = (*index, slice(None), slice(None), slice(None))[:3] + sizes = self.kv_num_blocks.shape[:3] + index = tuple( + (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1)) + if isinstance(i, int) + else i + for i, n in zip(padded, sizes, strict=True) + ) + new_kv_num_blocks = self.kv_num_blocks[index] + new_kv_indices = self.kv_indices[index] + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + new_full_kv_num_blocks = self.full_kv_num_blocks[index] + new_full_kv_indices = self.full_kv_indices[index] + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + BLOCK_SIZE=self.BLOCK_SIZE, + mask_mod=_sliced_mask_mod_error, + seq_lengths=self.seq_lengths, + compute_q_blocks=self.q_indices is not None, + ) + + def __repr__(self) -> str: + def shape_or_none(x: torch.Tensor | None): + return x.shape if x is not None else None + + return ( + f"BlockMask(\n" + f" kv_num_blocks={self.kv_num_blocks.shape},\n" + f" kv_indices={self.kv_indices.shape},\n" + f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks)},\n" + f" full_kv_indices={shape_or_none(self.full_kv_indices)},\n" + f" q_num_blocks={shape_or_none(self.q_num_blocks)},\n" + f" q_indices={shape_or_none(self.q_indices)},\n" + f" full_q_num_blocks={shape_or_none(self.full_q_num_blocks)},\n" + f" full_q_indices={shape_or_none(self.full_q_indices)},\n" + f" BLOCK_SIZE={self.BLOCK_SIZE},\n" + f" shape={self.shape},\n" + f" sparsity={self.sparsity():.2f}%,\n" + f" mask_mod={self.mask_mod.__name__ if hasattr(self.mask_mod, '__name__') else self.mask_mod}\n" + f")" + ) + + def _adjust(self, new_q_len: int, new_kv_len: int): + new_num_rows = (new_q_len + self.BLOCK_SIZE[0] - 1) // self.BLOCK_SIZE[0] + new_num_cols = (new_kv_len + self.BLOCK_SIZE[1] - 1) // self.BLOCK_SIZE[1] + new_kv_num_blocks, new_kv_indices = _adjust_num_blocks_and_indices( + self.kv_num_blocks, self.kv_indices, new_num_rows, new_num_cols + ) + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + ( + new_full_kv_num_blocks, + new_full_kv_indices, + ) = _adjust_num_blocks_and_indices( + self.full_kv_num_blocks, + self.full_kv_indices, + new_num_rows, + new_num_cols, + ) + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return self.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + self.BLOCK_SIZE, + self.mask_mod, + ) + + def numel(self): + """Returns the number of elements (not accounting for sparsity) in the mask.""" + shape = self.shape + + def _prod(xs): + return functools.reduce(operator.mul, xs, 1) + + return _prod(shape) + + def sparsity(self) -> float: + """Computes the percentage of blocks that are sparse (i.e. not computed)""" + total_size = self.numel() + computed_blocks = self.kv_num_blocks.sum() + if self.full_kv_num_blocks is not None: + computed_blocks += self.full_kv_num_blocks.sum() + + computed_size = computed_blocks.item() * self.BLOCK_SIZE[0] * self.BLOCK_SIZE[1] + dense_ratio = computed_size / total_size + return 100 * (1 - dense_ratio) + + def to_dense(self) -> Tensor: + """Returns a dense block that is equivalent to the block mask.""" + partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + # pyrefly: ignore [bad-return] + return partial_dense | _ordered_to_dense( + self.full_kv_num_blocks, self.full_kv_indices + ) + return partial_dense + + def to_string(self, grid_size=(20, 20), limit=4): + """Returns a string representation of the block mask. Quite nifty. + + If grid_size is -1, prints out an uncompressed version. Warning, it can be quite big! + """ + dense_mask = self.to_dense() + *batch_dims, num_rows, num_cols = dense_mask.shape + if isinstance(grid_size, int): + max_rows = grid_size + max_cols = grid_size + elif grid_size == -1: + max_rows = num_rows + max_cols = num_cols + else: + max_rows, max_cols = grid_size + + def create_block_vis(*batch_idx): + descriptors = [] + + descriptors.append(f"{batch_idx}") + + vis = ", ".join(reversed(descriptors)) + "\n" + + def summarize_section(section) -> str: + percentage = section.float().mean().item() + if percentage == 1: + return "â–ˆ" + elif percentage == 0: + return " " + else: + return "â–‘" + + def cdiv(a, b): + return (a + (b - 1)) // b + + row_step = max(1, cdiv(num_rows, max_rows)) + col_step = max(1, cdiv(num_cols, max_cols)) + + for r in range(0, num_rows, row_step): + for c in range(0, num_cols, col_step): + cur_mask = dense_mask + for idx in batch_idx: + cur_mask = cur_mask[idx] + char = summarize_section( + cur_mask[r : r + row_step, c : c + col_step] + ) + vis += char * 2 + vis += "\n" + return vis + + total_vis = [] + for idx, batch_idx in enumerate( + itertools.product(*[range(i) for i in batch_dims]) + ): + if idx == limit: + total_vis.append("...") + total_vis.append("To print out more, set BlockMask.to_string(limit=N)") + total_vis.append( + "You can also index (BlockMask[batch, head]) to choose a specific batch or head" + ) + break + block_vis = create_block_vis(*batch_idx) + total_vis.append(block_vis) + + return "\n".join(total_vis) + + def to(self, device: torch.device | str) -> "BlockMask": + """Moves the BlockMask to the specified device. + + Args: + device (torch.device or str): The target device to move the BlockMask to. + Can be a torch.device object or a string (e.g., 'cpu', 'cuda:0'). + + Returns: + BlockMask: A new BlockMask instance with all tensor components moved + to the specified device. + + Note: + This method does not modify the original BlockMask in-place. + Instead, it returns a new BlockMask instance where individual tensor attributes + may or may not be moved to the specified device, depending on their + current device placement. + """ + mapped_attributes = tree_map_only( + torch.Tensor, + lambda x: x.to(device), + self.as_tuple(flatten=False), + ) + return BlockMask(*mapped_attributes) + + def _flatten(self): + """Flatten BlockMask into a list of tensors and context.""" + tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS) + context = tuple(getattr(self, attr) for attr in self._CONTEXT_ATTRS) + return tensors, context + + @classmethod + def _unflatten(cls, tensors, context): + """Unflatten tensors and context back into a BlockMask.""" + kwargs = { + **dict(zip(cls._CONTEXT_ATTRS, context)), + **dict(zip(cls._TENSOR_ATTRS, tensors)), + } + # pyrefly: ignore [bad-argument-type] + return cls(**kwargs) + + def _flatten_with_keys(self): + """Flatten BlockMask with keys for better tracing.""" + tensors = tuple( + (GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS + ) + context = tuple( + (GetAttrKey(attr), getattr(self, attr)) for attr in self._CONTEXT_ATTRS + ) + return tensors, context + + +def _broadcast_to_dim(x, dim): + while x.dim() < dim: + x = x.unsqueeze(0) + return x + + +def _round_up_to_multiple(x, multiple): + return (x + multiple - 1) // multiple * multiple + + +def _convert_mask_to_block_mask( + mask: Tensor, + Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + separate_full_blocks: bool = False, +) -> tuple[Tensor, Tensor | None]: + assert mask.dtype == torch.bool + mask = _broadcast_to_dim(mask, 4) + + def padding_needed_for_multiple(x, multiple): + return _round_up_to_multiple(x, multiple) - x + + mask = torch.nn.functional.pad( + mask, + ( + 0, + padding_needed_for_multiple(mask.shape[-1], KV_BLOCK_SIZE), + 0, + padding_needed_for_multiple(mask.shape[-2], Q_BLOCK_SIZE), + ), + ) + B, H, Q, KV = mask.shape + assert Q % Q_BLOCK_SIZE == 0 + assert KV % KV_BLOCK_SIZE == 0 + mask = mask.view( + B, H, Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE + ) # [B, H, Q//Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, KV_BLOCK_SIZE] + mask = mask.permute( + 0, 1, 2, 4, 3, 5 + ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, Q_BLOCK_SIZE, KV_BLOCK_SIZE] + mask_block_sum = mask.sum( + dim=[-2, -1] + ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE] + if separate_full_blocks: + full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE + full_blocks = mask_block_sum == full_block_sum + partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum) + partial_blocks = partial_blocks.to(dtype=torch.int8) + full_blocks = full_blocks.to(dtype=torch.int8) + return partial_blocks, full_blocks + else: + partial_blocks = mask_block_sum > 0 + partial_blocks = partial_blocks.to(dtype=torch.int8) + return partial_blocks, None + + +def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: + """Returns a mask_mod that's the union of provided mask_mods""" + if not all(callable(arg) for arg in mask_mods): + raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") + + def or_mask(b, h, q_idx, kv_idx): + result = b.new_zeros((), dtype=torch.bool) + for mask in mask_mods: + result = result | mask(b, h, q_idx, kv_idx) + return result + + return or_mask + + +def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature: + """Returns a mask_mod that's the intersection of provided mask_mods""" + if not all(callable(arg) for arg in mask_mods): + raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}") + + def and_mask(b, h, q_idx, kv_idx): + result = b.new_ones((), dtype=torch.bool) + for mask in mask_mods: + result = result & mask(b, h, q_idx, kv_idx) + return result + + return and_mask + + +def _convert_block_mask_to_mask( + block_mask, + KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, + Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, +) -> Tensor: + assert block_mask.dim() == 4 + B, H, Q, KV = block_mask.shape + block_mask = block_mask.expand(Q_BLOCK_SIZE, KV_BLOCK_SIZE, *block_mask.shape) + block_mask = block_mask.permute(2, 3, 4, 0, 5, 1).reshape( + B, H, Q * Q_BLOCK_SIZE, KV * KV_BLOCK_SIZE + ) + return block_mask + + +def _create_sparse_block_from_block_mask( + block_mask: tuple[Tensor, Tensor | None], + mask_mod: Callable | None, + seq_lengths: tuple[int, int], + Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, + KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE, +) -> BlockMask: + partial_blocks, full_blocks = block_mask + + partial_bm = _dense_to_ordered(partial_blocks) + if full_blocks is not None: + full_bm: tuple[Tensor | None, Tensor | None] = _dense_to_ordered(full_blocks) + else: + full_bm = (None, None) + + return BlockMask.from_kv_blocks( + partial_bm[0], + partial_bm[1], + full_bm[0], + full_bm[1], + BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE), + mask_mod=mask_mod, + seq_lengths=seq_lengths, + ) + + +def create_mask( + mod_fn: _score_mod_signature | _mask_mod_signature, + B: int | None, + H: int | None, + Q_LEN: int, + KV_LEN: int, + device: DeviceLikeType | None = None, +) -> Tensor: + r"""This function creates a mask tensor from a mod_fn function. + + Args: + mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores. + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Sequence length of query. + KV_LEN (int): Sequence length of key/value. + device (str): Device to run the mask creation on. + + Returns: + mask (Tensor): A mask tensor with shape (B, H, M, N). + """ + if device is None: + device = torch.accelerator.current_accelerator() or "cpu" + if B is None: + B = 1 + if H is None: + H = 1 + b = torch.arange(0, B, device=device) + h = torch.arange(0, H, device=device) + m = torch.arange(0, Q_LEN, device=device) + n = torch.arange(0, KV_LEN, device=device) + mod_type = _get_mod_type(mod_fn) + + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + with TransformGetItemToIndex(): + if mod_type == _ModificationType.SCORE_MOD: + score_mod = mod_fn + score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score + out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n) + mask = torch.where(torch.isneginf(out), False, True) + return mask + elif mod_type == _ModificationType.MASK_MOD: + mask_mod = mod_fn + mask_mod = _vmap_for_bhqkv(mask_mod, prefix=()) + mask = mask_mod(b, h, m, n) + return mask + else: + raise AssertionError + + +def create_block_mask( + mask_mod: _mask_mod_signature, + B: int | None, + H: int | None, + Q_LEN: int, + KV_LEN: int, + device: DeviceLikeType | None = None, + BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE, + _compile=False, +) -> BlockMask: + r"""This function creates a block mask tuple from a mask_mod function. + + Args: + mask_mod (Callable): mask_mod function. This is a callable that defines the + masking pattern for the attention mechanism. It takes four arguments: + b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index). + It should return a boolean tensor indicating which attention connections are allowed (True) + or masked out (False). + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Sequence length of query. + KV_LEN (int): Sequence length of key/value. + device (str): Device to run the mask creation on. + BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value. + + Returns: + BlockMask: A BlockMask object that contains the block mask information. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + + block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") + query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) + output = flex_attention(query, key, value, block_mask=block_mask) + """ + if device is None: + device = torch.accelerator.current_accelerator() or "cpu" + mod_type = _get_mod_type(mask_mod) + assert mod_type == _ModificationType.MASK_MOD, ( + f"create-block_mask requires a mask_mod function! Got {mask_mod}" + ) + if B is None: + B = 1 + if H is None: + H = 1 + if isinstance(BLOCK_SIZE, int): + Q_BLOCK_SIZE = BLOCK_SIZE + KV_BLOCK_SIZE = BLOCK_SIZE + else: + Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE + + if _compile: + warnings.warn( + "_compile flag on create_block_mask was originally added to work around a torch.compile limitation. That limitation has since been addressed. So, to compile create_block_mask, we suggest doing torch.compile(create_block_mask). This still works for now, but will be removed in the future.", + DeprecationWarning, + stacklevel=2, + ) + return torch.compile(create_block_mask)( + mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE + ) + + mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device) + partial_block_mask, full_block_mask = _convert_mask_to_block_mask( + mask_tensor, + Q_BLOCK_SIZE=Q_BLOCK_SIZE, + KV_BLOCK_SIZE=KV_BLOCK_SIZE, + separate_full_blocks=True, + ) + block_mask = _create_sparse_block_from_block_mask( + (partial_block_mask, full_block_mask), + mask_mod, + (Q_LEN, KV_LEN), + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + ) + return block_mask + + +def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: + r"""Default block mask for flex attention. + If users don't specify any block sparse mask info, we create this + empty block sparse mask. Which creates a BlockMask with 1 block that is the full length + of the query and key tensors. + """ + device = query.device + return BlockMask.from_kv_blocks( + kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device), + kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device), + BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE, + seq_lengths=(1, 1), + ) + + +def _apply_kernel_options( + query: Tensor, + key: Tensor, + value: Tensor, + return_lse: bool, + kernel_options, + return_aux: AuxRequest | None = None, +): + kernel_options = {} if kernel_options is None else dict(kernel_options) + + if "BACKEND" in kernel_options and kernel_options.get( + "FORCE_USE_FLEX_ATTENTION", False + ): + # TODO: remove FORCE_USE_FLEX_ATTENTION once BACKEND is fully adopted. + raise RuntimeError( + "BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION. " + "BACKEND supersedes the legacy knob; please drop FORCE_USE_FLEX_ATTENTION " + "and only specify the desired BACKEND." + ) + + if "BACKEND" in kernel_options: + valid_backends = typing.get_args(_Backend) + if kernel_options["BACKEND"] not in valid_backends: + raise ValueError( + f"Invalid BACKEND value '{kernel_options['BACKEND']}'. " + f"Must be one of {valid_backends}" + ) + + kernel_options.setdefault("BACKEND", "AUTO") + kernel_options.setdefault("PRESCALE_QK", False) + kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) + kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) + # This forces all biases grad scatters to be done in the DQ iteration loop of the backwards + kernel_options.setdefault("WRITE_DQ", True) + + any_inputs_on_cpu_device = ( + query.device.type == "cpu" + or key.device.type == "cpu" + or value.device.type == "cpu" + ) + + # Determine what auxiliary outputs are needed + output_lse = return_lse + output_max = False + + if return_aux is not None: + # New API takes precedence over legacy parameters + output_lse = return_aux.lse + output_max = return_aux.max_scores + + # If forward kernel needs to return logsumexp is decided by this rule internally. + assert "OUTPUT_LOGSUMEXP" not in kernel_options + kernel_options["OUTPUT_LOGSUMEXP"] = True + if not output_lse: + # We used to check if q,k,v required grads but since captured buffers can require grad + # we always write unless in no_grad + kernel_options["OUTPUT_LOGSUMEXP"] = torch.is_grad_enabled() + if any_inputs_on_cpu_device: + # CPU with torch.compile now supports inference, and will not return lse + # TODO: support CPU for training and return lse + kernel_options["OUTPUT_LOGSUMEXP"] = False + + # If forward kernel needs to return max is decided by this rule internally. + assert "OUTPUT_MAX" not in kernel_options + kernel_options["OUTPUT_MAX"] = output_max + if any_inputs_on_cpu_device and output_max: + # CPU doesn't support returning max yet + # TODO: support CPU for returning max + raise NotImplementedError("Returning max scores is not supported on CPU.") + kernel_options["OUTPUT_MAX"] = False + + return kernel_options + + +def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None: + if query.size(-1) != key.size(-1): + raise ValueError( + f"Expect query and key/value to have the same embedding dimension " + f"but got E={query.size(-1)} and E={key.size(-1)}." + ) + + +def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None: + """TODO: Remove once non cuda/cpu devices support is added + We only need to check query since we have already that q,k,v are on the same device + """ + supported_devices = {"cuda", "cpu", "xpu", "hpu"} + if query.device.type not in supported_devices: + raise ValueError( + "FlexAttention is only supported on CUDA, CPU or HPU devices. " + f"Found input tensors on {query.device.type} device." + ) + + +def _enforce_mem_layouts( + query: Tensor, key: Tensor, value: Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Enforce memory layouts for query, key, and value tensors. + + For non-FP8 dtypes, no action is taken. + + For FP8 dtypes, we enforce the following memory layouts: + - Query tensor must be in row-major memory layout, as it will be the left-operand in the FP8 GEMM `q @ k.T`. + - Key tensor must be in row-major memory layout, as it will be transposed when used as the right-operand + in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM. + - Value tensor must be in column-major memory layout, as it will be the right-operand in the FP8 GEMM `softmax_scores @ v`. + + Returns the query, key, and value tensors with the enforced memory layouts. + """ + + def is_row_major(tensor: Tensor) -> bool: + return tensor.stride()[-1] == 1 + + def is_col_major(tensor: Tensor) -> bool: + return tensor.stride()[-2] == 1 + + # These memory layout constraint are only for FP8 GEMMs on NVIDIA GPU architectures >= SM89 and < SM100. + # This is because GPU arch < SM89 does not not support FP8 GEMMs, and + # SM100 has support for TN, NT, TT, NN layouts for FP8 GEMMs + # (i.e., left and right operands can be in row or column major layouts) + # so this check is only needed for older architectures. + # See: https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + gemm_precision = query.dtype + + should_enforce_mem_layout = ( + gemm_precision in fp8_dtypes + and torch.version.cuda is not None + and torch.cuda.get_device_capability("cuda") >= (8, 9) + and torch.cuda.get_device_capability("cuda") < (10, 0) + ) + if not should_enforce_mem_layout: + return query, key, value + + # Query must be in row-major memory layout as the left-operand in the FP8 GEMM `q @ k.T` + if not is_row_major(query): + query = query.contiguous() + + # Key must be in row-major memory layout as it will be transposed when used as the right-operand + # in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM. + if not is_row_major(key): + key = key.contiguous() + + # Value must be in column-major memory layout as the right-operand in the FP8 GEMM `softmax_scores @ v` + if not is_col_major(value): + value = value.transpose(-2, -1).contiguous().transpose(-2, -1) + return query, key, value + + +def flex_attention( + query: Tensor, + key: Tensor, + value: Tensor, + score_mod: _score_mod_signature | None = None, + block_mask: BlockMask | None = None, + scale: float | None = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: FlexKernelOptions | None = None, + *, + return_aux: AuxRequest | None = None, +) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, AuxOutput]: + r"""This function implements scaled dot product attention with an arbitrary attention score modification function. + + This function computes the scaled dot product attention between query, key, and value tensors with a user-defined + attention score modification function. The attention score modification function will be applied after the attention + scores have been calculated between the query and key tensors. The attention scores are calculated as follows: + + The ``score_mod`` function should have the following signature: + + .. code-block:: python + + def score_mod( + score: Tensor, + batch: Tensor, + head: Tensor, + q_idx: Tensor, + k_idx: Tensor + ) -> Tensor: + + Where: + - ``score``: A scalar tensor representing the attention score, + with the same data type and device as the query, key, and value tensors. + - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating + the batch index, query head index, query index, and key/value index, respectively. + These should have the ``torch.int`` data type and be located on the same device as the score tensor. + + Args: + query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance. + key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance. + value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`. For FP8 dtypes, should be in column-major memory layout for optimal performance. + score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied. + block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention. + scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. + return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. **Deprecated**: Use ``return_aux=AuxRequest(lse=True)`` instead. + kernel_options (Optional[FlexKernelOptions]): + Options to control the behavior of the underlying Triton kernels. + See :class:`FlexKernelOptions` for available options and usage examples. + return_aux (Optional[AuxRequest]): Specifies which auxiliary outputs to compute and return. + If None, only the attention output is returned. Use ``AuxRequest(lse=True, max_scores=True)`` + to request both auxiliary outputs. + + Returns: + output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. + + When ``return_aux`` is not None: + aux (AuxOutput): Auxiliary outputs with requested fields populated. + + When ``return_aux`` is None (deprecated paths): + lse (Tensor): Log-sum-exp of attention scores; shape :math:`(B, Hq, L)`. Only returned if ``return_lse=True``. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + + .. warning:: + `torch.nn.attention.flex_attention` is a prototype feature in PyTorch. + Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + """ + # Some basic input validation + _validate_sdpa_input(query, key, value) + _validate_embed_dim(query, key, value) + _validate_device(query, key, value) + query, key, value = _enforce_mem_layouts(query, key, value) + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise NotImplementedError("NYI: query, key, and value must be 4D tensors") + if (not enable_gqa) and query.size(-3) != key.size(-3): + raise ValueError( + f"Expect query and key/value to have the same number of heads " + f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. " + f"Try setting enable_gqa=True for GQA." + ) + if enable_gqa: + Hq = query.size(1) + Hkv = key.size(1) + if Hq % Hkv != 0: + raise ValueError( + f"Expect number of query heads to be a multiple of kv heads for GQA " + f"but got Hq={Hq} and Hkv={Hkv}." + ) + if query.size(0) != key.size(0): + if block_mask is None: + raise ValueError( + f"Expect query and key/value to have the same batch size, " + f"or non-none block_mask, " + f"but got block_mask=None, Bq={query.size(0)}, and Bkv={key.size(0)}." + ) + + if block_mask.kv_num_blocks.size(0) != query.size(0): + raise ValueError( + f"Expect query and key/value to have the same batch size, " + f"or block_mask and query to have the same batch size, " + f"but got Bq={query.size(0)}, Bkv={key.size(0)}, B_block_mask={block_mask.kv_num_blocks.size(0)}." + ) + + if score_mod is None: + score_mod = _identity + + if block_mask is None: + block_mask = _create_empty_block_mask(query, key) + + # If BlockMask was sliced, its mask_mod is intentionally replaced with an error-raising stub. + # This guard ensures we surface the intended error message before any shape-based checks. + if getattr(block_mask, "mask_mod", None) is _sliced_mask_mod_error: + raise RuntimeError("Cannot use mask_mod from a sliced BlockMask") + + if ( + block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE + and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE + ): + # This corresponds to the case where we essentially have a "no-op" block mask. + pass + else: + block_mask_q_len = block_mask.shape[-2] + block_mask_kv_len = block_mask.shape[-1] + if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len: + raise ValueError( + f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. " + "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask." + ) + elif ( + query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len + ) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len): + raise ValueError( + f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. " + "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!" + ) + assert query.size(-2) == block_mask_q_len + assert key.size(-2) == block_mask_kv_len + + if scale is None: + scale = 1.0 / math.sqrt(query.size(-1)) + + if query.device != block_mask.kv_num_blocks.device: # type: ignore[union-attr] + raise RuntimeError( + f"Expect q/k/v and block_mask to be on the same device " + f"but got {query.device} and {block_mask.kv_num_blocks.device}." # type: ignore[union-attr] + ) + + # Handle deprecation warnings for old parameters + if return_lse and return_aux is not None: + raise ValueError( + "Cannot specify both return_lse and return_aux. " + "return_lse is deprecated, please use return_aux=AuxRequest(lse=True) instead." + ) + elif return_lse and return_aux is None: + _warn_once( + "deprecated_return_lse", + "return_lse is deprecated and will be removed in v2.10. " + "Please use return_aux=AuxRequest(lse=True) instead.", + category=FutureWarning, + ) + + kernel_options = _apply_kernel_options( + query, + key, + value, + return_lse, + kernel_options, + return_aux, + ) + + def _finalize_outputs( + out, + lse, + max_scores, + *, + return_aux: AuxRequest | None, + return_lse: bool, + ): + """Normalize stats and build return value (aux-aware, legacy-compatible).""" + ln2 = math.log(2.0) + return_lse = return_lse or return_aux is not None and return_aux.lse + return_max = return_aux is not None and return_aux.max_scores + + lse_scaled = lse * ln2 if (return_lse and lse.numel() > 0) else None + max_scaled = ( + max_scores * ln2 if (return_max and max_scores.numel() > 0) else None + ) + + if return_aux is not None: + return out, AuxOutput( + lse=lse_scaled, + max_scores=max_scaled, + ) + + if return_lse: + return out, lse_scaled + + return out + + if torch.compiler.is_dynamo_compiling(): + # mark head_dim and number of heads to be static + for x in [query, key, value]: + torch._dynamo.mark_static(x, -3) + torch._dynamo.mark_static(x, -1) + + out, lse, max_scores = flex_attention_hop( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, # type: ignore[union-attr] + ) + return _finalize_outputs( + out, lse, max_scores, return_aux=return_aux, return_lse=return_lse + ) + + if not _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + _warn_once( + warning_id="flex_attention_performance", + message=( + "flex_attention called without torch.compile() - this will use an unfused implementation that materializes the full scores matrix instead of generating a fused kernel.\n\n" + "SOLUTION: Use torch.compile(flex_attention)(...)\n\n" + "If you want to debug your score_mod/mask_mod, you can set:\n" + "torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True\n\n" + "This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results." + ), + ) + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("flex_attention requires dynamo support") + + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass hop to it. So we wrap it in a dummy function. + def _flex_attention_hop_wrapper(*args, **kwargs): + return flex_attention_hop(*args, **kwargs) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend: str | Callable[..., Any] = ( + make_eager_backend_with_torch_function_mode(metadata_mode) + ) + else: + backend = "eager" + + if _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + flex_fn = _flex_attention_hop_wrapper + else: + flex_fn = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + ) + + out, lse, max_scores = flex_fn( + query, + key, + value, + score_mod, + block_mask.as_tuple(), # type: ignore[union-attr] + scale, + kernel_options, + ) + return _finalize_outputs( + out, lse, max_scores, return_aux=return_aux, return_lse=return_lse + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/varlen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..b20c1b4b2e49a37cf0e29603f20ef50e0caf6146 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/attention/varlen.py @@ -0,0 +1,326 @@ +""" +Variable-length attention implementation using Flash Attention. + +This module provides a high-level Python interface for variable-length attention +that calls into the optimized Flash Attention kernels. +""" + +import logging +from functools import lru_cache +from typing import Any, NamedTuple + +import torch + + +log = logging.getLogger(__name__) + +__all__ = ["varlen_attn", "AuxRequest"] + + +@lru_cache(maxsize=8) +def _should_use_cudnn(device_index: int) -> bool: + """Cache device capability check to avoid repeated CUDA calls.""" + return False + + +class AuxRequest(NamedTuple): + """ + Request which auxiliary outputs to compute from varlen_attn. + + Each field is a boolean indicating whether that auxiliary output should be computed. + """ + + lse: bool = False + + +@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={}) +def _varlen_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Private custom op for variable-length attention. + + This is the internal implementation. Users should use the public varlen_attn function instead. + """ + + use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index) + + if use_cudnn: + log.info("Using cuDNN backend for varlen_attn") + result = torch.ops.aten._cudnn_attention_forward( + query, + key, + value, + None, # attn_bias + cu_seq_q, + cu_seq_k, + max_q, + max_k, + True, # compute_log_sumexp + 0.0, # dropout_p hardcoded to 0.0 + is_causal, + False, # return_debug_mask + ) + # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask) + output, softmax_lse, rng_state = result[0], result[1], result[6] + else: + log.info("Using Flash Attention backend for varlen_attn") + output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward( + query, + key, + value, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, # dropout_p hardcoded to 0.0 + is_causal, + return_debug_mask=False, + ) + + rng_state_ = torch.zeros( + (2,), dtype=torch.uint64, device=query.device + ) # hardcoded since dropout is hardcoded to 0 + return output, softmax_lse, rng_state_ + + +@_varlen_attn.register_fake +def _varlen_attn_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Fake implementation for meta tensor computation and tracing. + + Based on the 3D varlen path from meta__flash_attention_forward: + - query shape: (total, num_heads, head_dim) + - logsumexp shape: (num_heads, total_q) + """ + # Output has same shape as query + output = torch.empty_like(query) + + # For varlen path: logsumexp shape is (num_heads, total_q) + total_q = query.size(0) + num_heads = query.size(1) + logsumexp = torch.empty( + (num_heads, total_q), dtype=torch.float, device=query.device + ) + + rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device) + + return output, logsumexp, rng_state + + +def varlen_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, + return_aux: AuxRequest | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Compute variable-length attention using Flash Attention. + This function is similar to scaled_dot_product_attention but optimized for + variable-length sequences using cumulative sequence position tensors. + Args: + - query (Tensor): Query tensor; shape :math:`(T_q, H, D)` + - key (Tensor): Key tensor; shape :math:`(T_k, H, D)` + - value (Tensor): Value tensor; shape :math:`(T_k, H, D)` + - cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)` + - cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)` + - max_q (int): Maximum query sequence length in the batch. + - max_k (int): Maximum key/value sequence length in the batch. + - is_causal (bool, optional): If set to True, applies causal masking (default: False). + - return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor. + + Shape legend: + - :math:`N`: Batch size + - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths) + - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths) + - :math:`H`: Number of attention heads + - :math:`D`: Head dimension + + Returns: + - Tensor: Output tensor from attention computation + - If ``return_aux`` is not None and ``return_aux.lse`` is True, returns a tuple of Tensors: + (output, lse), where lse is the logsumexp + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16 + >>> head_dim = embed_dim // num_heads + >>> seq_lengths = [] + >>> for _ in range(batch_size): + ... length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64 + ... seq_lengths.append(min(length, max_seq_len)) + >>> seq_lengths = torch.tensor(seq_lengths, device="cuda") + >>> total_tokens = seq_lengths.sum().item() + >>> + >>> # Create packed query, key, value tensors + >>> query = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> key = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> value = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> + >>> # Build cumulative sequence tensor + >>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + >>> cu_seq[1:] = seq_lengths.cumsum(0) + >>> max_len = seq_lengths.max().item() + >>> + >>> # Call varlen_attn + >>> output = varlen_attn( + ... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False + ... ) + """ + out, lse, _ = torch.ops.torch_attn._varlen_attn( + query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal + ) + if return_aux is not None and return_aux.lse: + return out, lse + return out + + +def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None: + query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs + out, lse, rng_state = output + + ctx.save_for_backward(query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.is_causal = is_causal + + +@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={}) +def _varlen_attn_backward( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool, + rng_state: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + unused = torch.empty(0, device=query.device) + + use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index) + if use_cudnn: + log.info("Using cuDNN backend for varlen_attn") + dq, dk, dv = torch.ops.aten._cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, + is_causal, + rng_state, + unused, + ) + else: + log.info("Using Flash Attention backend for varlen_attn") + dq, dk, dv = torch.ops.aten._flash_attention_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, + is_causal, + rng_state, + unused, + ) + return dq, dk, dv + + +@_varlen_attn_backward.register_fake +def _varlen_attn_backward_fake( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool, + rng_state: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Fake implementation for meta tensor computation and tracing. + """ + + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + return grad_query, grad_key, grad_value + + +def _backward( + ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor +) -> tuple[torch.Tensor | None, ...]: + query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state = ctx.saved_tensors + + max_q = ctx.max_q + max_k = ctx.max_k + is_causal = ctx.is_causal + + dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + is_causal, + rng_state, + ) + return dq, dk, dv, None, None, None, None, None, None + + +_varlen_attn.register_autograd(_backward, setup_context=_setup_context) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1da41c9221005f05f6f8ba23f6ee9af2074f4541 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23570d3079a6fd401ac3658b07ffb384a6faa88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/thnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/thnn.py new file mode 100644 index 0000000000000000000000000000000000000000..c56e923a84383a79c2a3f7ebddb3dfa1ce1f0953 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/backends/thnn.py @@ -0,0 +1,6 @@ +# mypy: allow-untyped-defs +# this is for historical pickle deserialization, it is not used otherwise + + +def _get_thnn_function_backend() -> None: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9a09aa31464fd4e88b2e46b4210561a70e42e7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__init__.py @@ -0,0 +1,36 @@ +from torch.ao.nn.intrinsic import ( + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) +from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401 + +# Include the subpackages in case user imports from it directly +from torch.nn.intrinsic import modules, qat, quantized # noqa: F401 + + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c056becec38dbadeb890cae4eb9915313745fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f0d85ab1aba10947e3fe0570a0ee49a4bb66e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__init__.py @@ -0,0 +1,33 @@ +from torch.nn.intrinsic.modules.fused import ( + _FusedModule, + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearBn1d", + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b521c49a6e70fb06221570189b124211fae44fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faad502383847c8b351dd7cf3e75f69db4405885 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba11c92f05d21d97473349eb13174bbca8ca4d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/modules/fused.py @@ -0,0 +1,33 @@ +from torch.ao.nn.intrinsic import ( + BNReLU2d, + BNReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearReLU, +) +from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearBn1d", + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..302ed488399965e7db72656cd4bc8e522085df3d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from torch.nn.intrinsic.qat.modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4a42dbb6057c5f0adf193ef4f79b16669c71a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc439cd2fa6b439cec7b6ba0c9eec372c60a0caf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,32 @@ +from torch.nn.intrinsic.qat.modules.conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) +from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d +from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8168efe3f2369d00a610963c4a8d633b5899bc8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e759661d3cb7f88270c363a2aed91886827839f7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..539355c26fdcd80abdc46d5a867ede01978eb03c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d30a3213f678aeb6a7b815db92b08da7bce3ea9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..f8dc1d49aad310819ef892b2ca6e984d1e392f06 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,39 @@ +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) + + +__all__ = [ + # Modules + "ConvBn1d", + "ConvBnReLU1d", + "ConvReLU1d", + "ConvBn2d", + "ConvBnReLU2d", + "ConvReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "ConvReLU3d", + # Utilities + "freeze_bn_stats", + "update_bn_stats", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..79567d67bd1f930b89f6270232efbefec222ac61 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,15 @@ +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import LinearBn1d + + +__all__ = [ + "LinearBn1d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..71705320075efd08040418fb69c492882a6bd03d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,15 @@ +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.intrinsic.qat import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c09fddf6e75f989349d206ee1dcab6c9eb25678 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,14 @@ +# to ensure customers can use the module below +# without importing it directly +from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401 +from torch.nn.intrinsic.quantized.modules import * # noqa: F403 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd507a240818f7ef743f2eda1c14f2835c1d37e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce645703ddb584c34d80c15b562f0dacd14350b4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from torch.nn.intrinsic.quantized.dynamic.modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b615fdf4937f67b68f2c3d5db97c730cf2376f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49d148aa61a549518de0e98dad2a441bd837eba9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfe2f86a0342783c412cfc650f22799961fd5c00 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b00a28594c4c487722706dece60b70694e1325d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..85df3dd4b441a342c8126dd4e8cea89f5bfa5c08 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,6 @@ +from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc79b04418f8a7a8128c52a64e938396a94ee94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py @@ -0,0 +1,17 @@ +from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d +from torch.nn.intrinsic.quantized.modules.conv_relu import ( + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, +) +from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "BNReLU2d", + "BNReLU3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0fa38f4768b96f5d474f3b09de2d3365c3c225f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f13472aa254694846ea29ce7cab6c6f0543d43cb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..636467db551e40e6992846571e79edee08c2454d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950b8319e5fdd4b2cc63f3fbf47cb5ba9aa3b7be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..f17253c93246f52c49b5310788028e91c61b92ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,7 @@ +from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..5a72b4d5ebe8afd13d67eb6c1e9d227fce7aa6b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,8 @@ +from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..2d848c84d927669f2381edcbadd7e557ac429822 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,6 @@ +from torch.ao.nn.intrinsic.quantized import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99260ad43fc477c36a9780c057824f57d4914719 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__init__.py @@ -0,0 +1,334 @@ +from .module import Module # usort: skip +from .linear import Bilinear, Identity, LazyLinear, Linear # usort: skip +from .activation import ( + CELU, + ELU, + GELU, + GLU, + Hardshrink, + Hardsigmoid, + Hardswish, + Hardtanh, + LeakyReLU, + LogSigmoid, + LogSoftmax, + Mish, + MultiheadAttention, + PReLU, + ReLU, + ReLU6, + RReLU, + SELU, + Sigmoid, + SiLU, + Softmax, + Softmax2d, + Softmin, + Softplus, + Softshrink, + Softsign, + Tanh, + Tanhshrink, + Threshold, +) +from .adaptive import AdaptiveLogSoftmaxWithLoss +from .batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + LazyBatchNorm1d, + LazyBatchNorm2d, + LazyBatchNorm3d, + SyncBatchNorm, +) +from .channelshuffle import ChannelShuffle +from .container import ( + Container, + ModuleDict, + ModuleList, + ParameterDict, + ParameterList, + Sequential, +) +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + LazyConv1d, + LazyConv2d, + LazyConv3d, + LazyConvTranspose1d, + LazyConvTranspose2d, + LazyConvTranspose3d, +) +from .distance import CosineSimilarity, PairwiseDistance +from .dropout import ( + AlphaDropout, + Dropout, + Dropout1d, + Dropout2d, + Dropout3d, + FeatureAlphaDropout, +) +from .flatten import Flatten, Unflatten +from .fold import Fold, Unfold +from .instancenorm import ( + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LazyInstanceNorm1d, + LazyInstanceNorm2d, + LazyInstanceNorm3d, +) +from .loss import ( + BCELoss, + BCEWithLogitsLoss, + CosineEmbeddingLoss, + CrossEntropyLoss, + CTCLoss, + GaussianNLLLoss, + HingeEmbeddingLoss, + HuberLoss, + KLDivLoss, + L1Loss, + MarginRankingLoss, + MSELoss, + MultiLabelMarginLoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + NLLLoss, + NLLLoss2d, + PoissonNLLLoss, + SmoothL1Loss, + SoftMarginLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, +) +from .normalization import ( + CrossMapLRN2d, + GroupNorm, + LayerNorm, + LocalResponseNorm, + RMSNorm, +) +from .padding import ( + CircularPad1d, + CircularPad2d, + CircularPad3d, + ConstantPad1d, + ConstantPad2d, + ConstantPad3d, + ReflectionPad1d, + ReflectionPad2d, + ReflectionPad3d, + ReplicationPad1d, + ReplicationPad2d, + ReplicationPad3d, + ZeroPad1d, + ZeroPad2d, + ZeroPad3d, +) +from .pixelshuffle import PixelShuffle, PixelUnshuffle +from .pooling import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + AdaptiveMaxPool1d, + AdaptiveMaxPool2d, + AdaptiveMaxPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + FractionalMaxPool2d, + FractionalMaxPool3d, + LPPool1d, + LPPool2d, + LPPool3d, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MaxUnpool1d, + MaxUnpool2d, + MaxUnpool3d, +) +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase +from .sparse import Embedding, EmbeddingBag +from .transformer import ( + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d + + +__all__ = [ + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AdaptiveLogSoftmaxWithLoss", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AlphaDropout", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "BCELoss", + "BCEWithLogitsLoss", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "Bilinear", + "CELU", + "CTCLoss", + "ChannelShuffle", + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "Container", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "CosineEmbeddingLoss", + "CosineSimilarity", + "CrossEntropyLoss", + "CrossMapLRN2d", + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "ELU", + "Embedding", + "EmbeddingBag", + "FeatureAlphaDropout", + "Flatten", + "Fold", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "GELU", + "GLU", + "GRU", + "GRUCell", + "GaussianNLLLoss", + "GroupNorm", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "HingeEmbeddingLoss", + "HuberLoss", + "Identity", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "KLDivLoss", + "L1Loss", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "LSTM", + "LSTMCell", + "LayerNorm", + "LazyBatchNorm1d", + "LazyBatchNorm2d", + "LazyBatchNorm3d", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", + "LazyLinear", + "LeakyReLU", + "Linear", + "LocalResponseNorm", + "LogSigmoid", + "LogSoftmax", + "MSELoss", + "MarginRankingLoss", + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "Mish", + "Module", + "ModuleDict", + "ModuleList", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "MultiheadAttention", + "NLLLoss", + "NLLLoss2d", + "PReLU", + "PairwiseDistance", + "ParameterDict", + "ParameterList", + "PixelShuffle", + "PixelUnshuffle", + "PoissonNLLLoss", + "RMSNorm", + "RNN", + "RNNBase", + "RNNCell", + "RNNCellBase", + "RReLU", + "ReLU", + "ReLU6", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "SELU", + "Sequential", + "SiLU", + "Sigmoid", + "SmoothL1Loss", + "SoftMarginLoss", + "Softmax", + "Softmax2d", + "Softmin", + "Softplus", + "Softshrink", + "Softsign", + "SyncBatchNorm", + "Tanh", + "Tanhshrink", + "Threshold", + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + "TransformerEncoderLayer", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "Unflatten", + "Unfold", + "Upsample", + "UpsamplingBilinear2d", + "UpsamplingNearest2d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54ad35fc0f070a94c7b760585c6f3c141dccc840 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1420a4b38cfac2d296c2e6518777cee58f5b85b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303151243d1a2904518dcce0355a4d0f92d3ca05 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0577443795c6cca4ce2576c9f34b3ead6ee1f45e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b06330ab65b02eb51912a8cc2a4514afdeaf3c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e331d588324e784354c068bbff4419825a48a222 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60ee98c6c5a78b60aa3074d32b75f1a0fb19d13b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e460261fd542a3838a0b684247c9dff835ce3736 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae9e46bac37e9e2e5891e60dc09ad911ea9ccc1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1d8b3229fba37f738fac1749a0764bcf18b7676 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ab2f7fb6675cd90f98c21ce4a99d979be495786 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2490ba47d56ea582357f3e216603f62dcb0f7b70 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df6e75bf0a30beb28707ee3aa8d5794580aabf1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17065520826c69918102f1bdf6236cb8c856c006 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e518694e283156b20b852e3d397d986e173e170 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5fd223740733e268c34eb68a3cbd2621cd5af13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b9aee1571182047fc2f71438fe598dc61a5cbec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6058499d50d994db547dbe4b665ef4ac1cff63 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c537b52594f53223db93120ed7b5831be438e2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b10b549eb7a584aa6d9fca137274c558170aa6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4ad8afd9951bbb53e32d3fc46bacefbeaddb975 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..213ab5bb1067d928c4661a4275788b3f20377c9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f475a45a38964f12cc6fa9d79f0a03c4252f22b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f151977ab0b59021f00d971c64fe54e37114a8d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..408e6ef42f12843ddbfc38d540fc68e454c9e958 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/_functions.py @@ -0,0 +1,319 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch.autograd.function import Function + + +class SyncBatchNorm(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward( + self, + input, + weight, + bias, + running_mean, + running_var, + eps, + momentum, + process_group, + world_size, + ): + if not ( + input.is_contiguous(memory_format=torch.channels_last) + or input.is_contiguous(memory_format=torch.channels_last_3d) + ): + input = input.contiguous() + if weight is not None: + weight = weight.contiguous() + + size = int(input.numel() // input.size(1)) + if size == 1 and world_size < 2: + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) + + num_channels = input.shape[1] + if input.numel() > 0: + # calculate mean/invstd for input. + mean, invstd = torch.batch_norm_stats(input, eps) + + count = torch.full( + (1,), + input.numel() // input.size(1), + dtype=mean.dtype, + device=mean.device, + ) + + # C, C, 1 -> (2C + 1) + combined = torch.cat([mean, invstd, count], dim=0) + else: + # for empty input, set stats and the count to zero. The stats with + # zero count will be filtered out later when computing global mean + # & invstd, but they still needs to participate the all_gather + # collective communication to unblock other peer processes. + combined = torch.zeros( + 2 * num_channels + 1, dtype=input.dtype, device=input.device + ) + + # Use allgather instead of allreduce because count could be different across + # ranks, simple all reduce op can not give correct results. + # batch_norm_gather_stats_with_counts calculates global mean & invstd based on + # all gathered mean, invstd and count. + # for nccl backend, use the optimized version of all gather. + # The Gloo backend does not support `all_gather_into_tensor`. + if process_group._get_backend_name() != "gloo": + # world_size * (2C + 1) + combined_size = combined.numel() + combined_flat = torch.empty( + 1, + combined_size * world_size, + dtype=combined.dtype, + device=combined.device, + ) + dist.all_gather_into_tensor( + combined_flat, combined, process_group, async_op=False + ) + combined = torch.reshape(combined_flat, (world_size, combined_size)) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + else: + # world_size * (2C + 1) + combined_list = [torch.empty_like(combined) for _ in range(world_size)] + dist.all_gather(combined_list, combined, process_group, async_op=False) + combined = torch.stack(combined_list, dim=0) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + + if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()): + # The lines below force a synchronization between CUDA and CPU, because + # the shape of the result count_all depends on the values in mask tensor. + # Such synchronizations break CUDA Graph capturing. + # See https://github.com/pytorch/pytorch/issues/78549 + # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes + # a better longer-term solution. + + # remove stats from empty inputs + mask = count_all.squeeze(-1) >= 1 + count_all = count_all[mask] + mean_all = mean_all[mask] + invstd_all = invstd_all[mask] + + # calculate global mean & invstd + counts = count_all.view(-1) + if running_mean is not None and counts.dtype != running_mean.dtype: + counts = counts.to(running_mean.dtype) + mean, invstd = torch.batch_norm_gather_stats_with_counts( + input, + mean_all, + invstd_all, + running_mean, + running_var, + momentum, + eps, + counts, + ) + + self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32)) + self.process_group = process_group + + # apply element-wise normalization + if input.numel() > 0: + return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps) + else: + return torch.empty_like(input) + + @staticmethod + def backward(self, grad_output): + if not ( + grad_output.is_contiguous(memory_format=torch.channels_last) + or grad_output.is_contiguous(memory_format=torch.channels_last_3d) + ): + grad_output = grad_output.contiguous() + saved_input, weight, mean, invstd, count_tensor = self.saved_tensors + grad_input = grad_weight = grad_bias = None + process_group = self.process_group + + if saved_input.numel() > 0: + # calculate local stats as well as grad_weight / grad_bias + ( + sum_dy, + sum_dy_xmu, + grad_weight, + grad_bias, + ) = torch.batch_norm_backward_reduce( + grad_output, + saved_input, + mean, + invstd, + weight, + self.needs_input_grad[0], + self.needs_input_grad[1], + self.needs_input_grad[2], + ) + + if self.needs_input_grad[0]: + # synchronizing stats used to calculate input gradient. + num_channels = sum_dy.shape[0] + combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) + torch.distributed.all_reduce( + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) + sum_dy, sum_dy_xmu = torch.split(combined, num_channels) + + # backward pass for gradient calculation + if weight is not None and weight.dtype != mean.dtype: + weight = weight.to(mean.dtype) + grad_input = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor, + ) + # synchronizing of grad_weight / grad_bias is not needed as distributed + # training would handle all reduce. + if weight is None or not self.needs_input_grad[1]: + grad_weight = None + + if weight is None or not self.needs_input_grad[2]: + grad_bias = None + else: + # This process got an empty input tensor in the forward pass. + # Although this process can directly set grad_input as an empty + # tensor of zeros, it still needs to participate in the collective + # communication to unblock its peers, as other peer processes might + # have received non-empty inputs. + num_channels = saved_input.shape[1] + if self.needs_input_grad[0]: + # launch all_reduce to unblock other peer processes + combined = torch.zeros( + 2 * num_channels, dtype=saved_input.dtype, device=saved_input.device + ) + torch.distributed.all_reduce( + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) + + # Leave grad_input, grad_weight and grad_bias as None, which will be + # interpreted by the autograd engine as Tensors full of zeros. + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class CrossMapLRN2d(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): + ctx.size = size + ctx.alpha = alpha + ctx.beta = beta + ctx.k = k + ctx.scale = None + + if input.dim() != 4: + raise ValueError( + f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead." + ) + + ctx.scale = ctx.scale or input.new() + output = input.new() + channels = input.size(1) + + output.resize_as_(input) + ctx.scale.resize_as_(input) + + # use output storage as temporary buffer + input_square = output + torch.pow(input, 2, out=input_square) + + pre_pad = int((ctx.size - 1) / 2 + 1) + pre_pad_crop = min(pre_pad, channels) + + scale_first = ctx.scale.select(1, 0) + scale_first.zero_() + # compute first feature map normalization + for c in range(pre_pad_crop): + scale_first.add_(input_square.select(1, c)) + + # reuse computations for next feature maps normalization + # by adding the next feature map and removing the previous + for c in range(1, channels): + scale_previous = ctx.scale.select(1, c - 1) + scale_current = ctx.scale.select(1, c) + scale_current.copy_(scale_previous) + if c < channels - pre_pad + 1: + square_next = input_square.select(1, c + pre_pad - 1) + scale_current.add_(square_next, alpha=1) + + if c > pre_pad: + square_previous = input_square.select(1, c - pre_pad) + scale_current.add_(square_previous, alpha=-1) + + ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k) + + torch.pow(ctx.scale, -ctx.beta, out=output) + output.mul_(input) + + ctx.save_for_backward(input, output) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + input, output = ctx.saved_tensors + grad_input = grad_output.new() + + batch_size = input.size(0) + channels = input.size(1) + input_height = input.size(2) + input_width = input.size(3) + + paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width) + accum_ratio = input.new(input_height, input_width) + + cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size + inversePrePad = int(ctx.size - (ctx.size - 1) / 2) + + grad_input.resize_as_(input) + torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output) + + paddded_ratio.zero_() + padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels) + for n in range(batch_size): + torch.mul(grad_output[n], output[n], out=padded_ratio_center) + padded_ratio_center.div_(ctx.scale[n]) + torch.sum( + paddded_ratio.narrow(0, 0, ctx.size - 1), + 0, + keepdim=False, + out=accum_ratio, + ) + for c in range(channels): + accum_ratio.add_(paddded_ratio[c + ctx.size - 1]) + grad_input[n][c].addcmul_( + input[n][c], accum_ratio, value=-cache_ratio_value + ) + accum_ratio.add_(paddded_ratio[c], alpha=-1) + + return grad_input, None, None, None, None + + +class BackwardHookFunction(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, *args): + ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) + return args + + @staticmethod + def backward(ctx, *args): + return args diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..dac27cdb0d2464847a85e4ee8683326188875977 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/activation.py @@ -0,0 +1,1905 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.parameter import Parameter + +from .linear import NonDynamicallyQuantizableLinear +from .module import Module + + +__all__ = [ + "Threshold", + "ReLU", + "RReLU", + "Hardtanh", + "ReLU6", + "Sigmoid", + "Hardsigmoid", + "Tanh", + "SiLU", + "Mish", + "Hardswish", + "ELU", + "CELU", + "SELU", + "GLU", + "GELU", + "Hardshrink", + "LeakyReLU", + "LogSigmoid", + "Softplus", + "Softshrink", + "MultiheadAttention", + "PReLU", + "Softsign", + "Tanhshrink", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", +] + + +class Threshold(Module): + r"""Thresholds each element of the input Tensor. + + Threshold is defined as: + + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} + + Args: + threshold: The value to threshold at + value: The value to replace with + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Threshold.png + + Examples:: + + >>> m = nn.Threshold(0, 0.5) + >>> input = torch.arange(-3, 3) + >>> output = m(input) + """ + + __constants__ = ["threshold", "value", "inplace"] + + threshold: float + value: float + inplace: bool + + def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: + super().__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + # TODO: check in THNN (if inplace == True, then assert value <= threshold) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.threshold(input, self.threshold, self.value, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"threshold={self.threshold}, value={self.value}{inplace_str}" + + +class ReLU(Module): + r"""Applies the rectified linear unit function element-wise. + + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ReLU.png + + Examples:: + + >>> m = nn.ReLU() + >>> input = torch.randn(2) + >>> output = m(input) + + + An implementation of CReLU - https://arxiv.org/abs/1603.05201 + + >>> m = nn.ReLU() + >>> input = torch.randn(2).unsqueeze(0) + >>> output = torch.cat((m(input), m(-input))) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class RReLU(Module): + r"""Applies the randomized leaky rectified linear unit function, element-wise. + + Method described in the paper: + `Empirical Evaluation of Rectified Activations in Convolutional Network `_. + + The function is defined as: + + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during + evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`. + + Args: + lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/RReLU.png + + Examples:: + + >>> m = nn.RReLU(0.1, 0.3) + >>> input = torch.randn(2) + >>> output = m(input) + + """ + + __constants__ = ["lower", "upper", "inplace"] + + lower: float + upper: float + inplace: bool + + def __init__( + self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False + ) -> None: + super().__init__() + self.lower = lower + self.upper = upper + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"lower={self.lower}, upper={self.upper}{inplace_str}" + + +class Hardtanh(Module): + r"""Applies the HardTanh function element-wise. + + HardTanh is defined as: + + .. math:: + \text{HardTanh}(x) = \begin{cases} + \text{max\_val} & \text{ if } x > \text{ max\_val } \\ + \text{min\_val} & \text{ if } x < \text{ min\_val } \\ + x & \text{ otherwise } \\ + \end{cases} + + Args: + min_val: minimum value of the linear region range. Default: -1 + max_val: maximum value of the linear region range. Default: 1 + inplace: can optionally do the operation in-place. Default: ``False`` + + Keyword arguments :attr:`min_value` and :attr:`max_value` + have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardtanh.png + + Examples:: + + >>> m = nn.Hardtanh(-2, 2) + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["min_val", "max_val", "inplace"] + + min_val: float + max_val: float + inplace: bool + + def __init__( + self, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, + min_value: float | None = None, + max_value: float | None = None, + ) -> None: + super().__init__() + if min_value is not None: + warnings.warn( + "keyword argument `min_value` is deprecated and rename to `min_val`", + FutureWarning, + stacklevel=2, + ) + min_val = min_value + if max_value is not None: + warnings.warn( + "keyword argument `max_value` is deprecated and rename to `max_val`", + FutureWarning, + stacklevel=2, + ) + max_val = max_value + + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + assert self.max_val > self.min_val + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.hardtanh(input, self.min_val, self.max_val, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}" + + +class ReLU6(Hardtanh): + r"""Applies the ReLU6 function element-wise. + + .. math:: + \text{ReLU6}(x) = \min(\max(0,x), 6) + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.ReLU6() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def __init__(self, inplace: bool = False) -> None: + super().__init__(0.0, 6.0, inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Sigmoid(Module): + r"""Applies the Sigmoid function element-wise. + + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Sigmoid.png + + Examples:: + + >>> m = nn.Sigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return torch.sigmoid(input) + + +class Hardsigmoid(Module): + r"""Applies the Hardsigmoid function element-wise. + + Hardsigmoid is defined as: + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardsigmoid.png + + Examples:: + + >>> m = nn.Hardsigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.hardsigmoid(input, self.inplace) + + +class Tanh(Module): + r"""Applies the Hyperbolic Tangent (Tanh) function element-wise. + + Tanh is defined as: + + .. math:: + \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Tanh.png + + Examples:: + + >>> m = nn.Tanh() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return torch.tanh(input) + + +class SiLU(Module): + r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. + + The SiLU function is also known as the swish function. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/SiLU.png + + Examples:: + + >>> m = nn.SiLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.silu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Mish(Module): + r"""Applies the Mish function, element-wise. + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Mish.png + + Examples:: + + >>> m = nn.Mish() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.mish(input, inplace=self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Hardswish(Module): + r"""Applies the Hardswish function, element-wise. + + Method described in the paper: `Searching for MobileNetV3 `_. + + Hardswish is defined as: + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardswish.png + + Examples:: + + >>> m = nn.Hardswish() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.hardswish(input, self.inplace) + + +class ELU(Module): + r"""Applies the Exponential Linear Unit (ELU) function, element-wise. + + Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear + Units (ELUs) `__. + + ELU is defined as: + + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} + + Args: + alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/ELU.png + + Examples:: + + >>> m = nn.ELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.elu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" + + +class CELU(Module): + r"""Applies the CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) + + More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . + + Args: + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/CELU.png + + Examples:: + + >>> m = nn.CELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _`Continuously Differentiable Exponential Linear Units`: + https://arxiv.org/abs/1704.07483 + """ + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" + + +class SELU(Module): + r"""Applies the SELU function element-wise. + + .. math:: + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + .. warning:: + When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, + ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` + in order to get `Self-Normalizing Neural Networks`_. + See :func:`torch.nn.init.calculate_gain` for more information. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/SELU.png + + Examples:: + + >>> m = nn.SELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.selu(input, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class GLU(Module): + r"""Applies the gated linear unit function. + + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + .. image:: ../scripts/activation_images/GLU.png + + Examples:: + + >>> m = nn.GLU() + >>> input = torch.randn(4, 2) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: int + + def __init__(self, dim: int = -1) -> None: + super().__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.glu(input, self.dim) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"dim={self.dim}" + + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units function. + + .. math:: \text{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + When the approximate argument is 'tanh', Gelu is estimated with: + + .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + + Args: + approximate (str, optional): the gelu approximation algorithm to use: + ``'none'`` | ``'tanh'``. Default: ``'none'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["approximate"] + approximate: str + + def __init__(self, approximate: str = "none") -> None: + super().__init__() + self.approximate = approximate + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.gelu(input, approximate=self.approximate) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"approximate={repr(self.approximate)}" + + +class Hardshrink(Module): + r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. + + Hardshrink is defined as: + + .. math:: + \text{HardShrink}(x) = + \begin{cases} + x, & \text{ if } x > \lambda \\ + x, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Hardshrink.png + + Examples:: + + >>> m = nn.Hardshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + """ + Run forward pass. + """ + return F.hardshrink(input, self.lambd) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"{self.lambd}" + + +class LeakyReLU(Module): + r"""Applies the LeakyReLU function element-wise. + + .. math:: + \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) + + + or + + .. math:: + \text{LeakyReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{negative\_slope} \times x, & \text{ otherwise } + \end{cases} + + Args: + negative_slope: Controls the angle of the negative slope (which is used for + negative input values). Default: 1e-2 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + .. image:: ../scripts/activation_images/LeakyReLU.png + + Examples:: + + >>> m = nn.LeakyReLU(0.1) + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["inplace", "negative_slope"] + inplace: bool + negative_slope: float + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super().__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + """ + Run forward pass. + """ + return F.leaky_relu(input, self.negative_slope, self.inplace) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + inplace_str = ", inplace=True" if self.inplace else "" + return f"negative_slope={self.negative_slope}{inplace_str}" + + +class LogSigmoid(Module): + r"""Applies the Logsigmoid function element-wise. + + .. math:: + \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/LogSigmoid.png + + Examples:: + + >>> m = nn.LogSigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Run forward pass. + """ + return F.logsigmoid(input) + + +class Softplus(Module): + r"""Applies the Softplus function element-wise. + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) + + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. + + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. + + Args: + beta: the :math:`\beta` value for the Softplus formulation. Default: 1 + threshold: values above this revert to a linear function. Default: 20 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softplus.png + + Examples:: + + >>> m = nn.Softplus() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["beta", "threshold"] + beta: float + threshold: float + + def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None: + super().__init__() + self.beta = beta + self.threshold = threshold + + def forward(self, input: Tensor) -> Tensor: + """ + Run forward pass. + """ + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"beta={self.beta}, threshold={self.threshold}" + + +class Softshrink(Module): + r"""Applies the soft shrinkage function element-wise. + + .. math:: + \text{SoftShrinkage}(x) = + \begin{cases} + x - \lambda, & \text{ if } x > \lambda \\ + x + \lambda, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softshrink.png + + Examples:: + + >>> m = nn.Softshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + """ + Run forward pass. + """ + return F.softshrink(input, self.lambd) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return str(self.lambd) + + +def _check_arg_device(x: torch.Tensor | None) -> bool: + if x is not None: + return x.device.type in [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] + return True + + +def _arg_requires_grad(x: torch.Tensor | None) -> bool: + if x is not None: + return x.requires_grad + return False + + +def _is_make_fx_tracing(): + if not torch.jit.is_scripting(): + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + # this can be triggered when dynamo inlining the module too. + return ( + any( + type(x) is torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode + for x in torch_dispatch_mode_stack + ) + or torch.compiler.is_exporting() + ) + else: + return False + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information from different representation subspaces. + + This MultiheadAttention layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O + + where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``nn.MultiheadAttention`` will use the optimized implementations of + ``scaled_dot_product_attention()`` when possible. + + In addition to support for the new ``scaled_dot_product_attention()`` + function, for speeding up Inference, MHA will use + fastpath inference with support for Nested Tensors, iff: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor). + - inputs are batched (3D) with ``batch_first==True`` + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + - autocast is disabled + + If the optimized inference fastpath implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ["batch_first"] + bias_k: torch.Tensor | None + bias_v: torch.Tensor | None + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: + if embed_dim <= 0 or num_heads <= 0: + raise ValueError( + f"embed_dim and num_heads must be greater than 0," + f" got embed_dim={embed_dim} and num_heads={num_heads} instead" + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) + + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self) -> None: + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super().__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Tensor | None = None, + need_weights: bool = True, + attn_mask: Tensor | None = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Tensor | None]: + r"""Compute attention outputs using query, key, and value embeddings. + + Supports optional parameters for padding, masks and attention weights. + + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and float masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` + and achieve the best performance for MHA. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + If both attn_mask and key_padding_mask are supplied, their types should match. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``attn_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ # noqa: B950 + why_not_fast_path = "" + if ( + (attn_mask is not None and torch.is_floating_point(attn_mask)) + or (key_padding_mask is not None) + and torch.is_floating_point(key_padding_mask) + ): + why_not_fast_path = "floating-point masks are not supported for fast path." + + is_batched = query.dim() == 3 + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + if not is_fastpath_enabled: + why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" + elif not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif self.in_proj_weight is None: + why_not_fast_path = "in_proj_weight was None" + elif query.dtype != self.in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif (self.num_heads % 2) != 0: + why_not_fast_path = "self.num_heads is not even" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif query.is_nested and ( + key_padding_mask is not None or attn_mask is not None + ): + why_not_fast_path = ( + "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" + ) + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif _is_make_fx_tracing(): + why_not_fast_path = "we are running make_fx tracing" + elif not all(_check_arg_device(x) for x in tensor_args): + why_not_fast_path = ( + "some Tensor argument's device is neither one of " + f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}" + ) + elif torch.is_grad_enabled() and any( + _arg_requires_grad(x) for x in tensor_args + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + merged_mask, mask_type = self.merge_masks( + attn_mask, key_padding_mask, query + ) + + if self.in_proj_bias is not None and self.in_proj_weight is not None: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + merged_mask, + need_weights, + average_attn_weights, + mask_type, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key + else: + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + def merge_masks( + self, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + query: Tensor, + ) -> tuple[Tensor | None, int | None]: + r"""Determine mask type and combine masks if necessary. + + If only one mask is provided, that mask + and the corresponding mask type will be returned. If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: int | None = None + merged_mask: Tensor | None = None + + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + + if attn_mask is not None: + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + + # Always expands attn_mask to 4D + if attn_mask.dim() == 3: + attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) + else: # attn_mask.dim() == 2: + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( + batch_size, self.num_heads, -1, -1 + ) + merged_mask = attn_mask_expanded + + if key_padding_mask is not None: + key_padding_mask_expanded = key_padding_mask.view( + batch_size, 1, 1, seq_len + ).expand(-1, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded + key_padding_mask_expanded + + # no attn_mask and no key_padding_mask, returns None, None + return merged_mask, mask_type + + +class PReLU(Module): + r"""Applies the element-wise PReLU function. + + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + + or + + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \ge 0 \\ + ax, & \text{ otherwise } + \end{cases} + + Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. + + + .. note:: + weight decay should not be used when learning :math:`a` for good performance. + + .. note:: + Channel dim is the 2nd dim of input. When input has dims < 2, then there is + no channel dim and the number of channels = 1. + + Args: + num_parameters (int): number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. Default: 1 + init (float): the initial value of :math:`a`. Default: 0.25 + + Shape: + - Input: :math:`( *)` where `*` means, any number of additional + dimensions. + - Output: :math:`(*)`, same shape as the input. + + Attributes: + weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). + + .. image:: ../scripts/activation_images/PReLU.png + + Examples:: + + >>> m = nn.PReLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + __constants__ = ["num_parameters"] + num_parameters: int + + def __init__( + self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + self.num_parameters = num_parameters + super().__init__() + self.init = init + self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + torch.nn.init.constant_(self.weight, self.init) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"num_parameters={self.num_parameters}" + + +class Softsign(Module): + r"""Applies the element-wise Softsign function. + + .. math:: + \text{SoftSign}(x) = \frac{x}{ 1 + |x|} + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Softsign.png + + Examples:: + + >>> m = nn.Softsign() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.softsign(input) + + +class Tanhshrink(Module): + r"""Applies the element-wise Tanhshrink function. + + .. math:: + \text{Tanhshrink}(x) = x - \tanh(x) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + .. image:: ../scripts/activation_images/Tanhshrink.png + + Examples:: + + >>> m = nn.Tanhshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.tanhshrink(input) + + +class Softmin(Module): + r"""Applies the Softmin function to an n-dimensional input Tensor. + + Rescales them so that the elements of the n-dimensional output Tensor + lie in the range `[0, 1]` and sum to 1. + + Softmin is defined as: + + .. math:: + \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Args: + dim (int): A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). + + Returns: + a Tensor of the same dimension and shape as the input, with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmin(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: int | None + + def __init__(self, dim: int | None = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.softmin(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"dim={self.dim}" + + +class Softmax(Module): + r"""Applies the Softmax function to an n-dimensional input Tensor. + + Rescales them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. + + Softmax is defined as: + + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + When the input Tensor is a sparse tensor then the unspecified + values are treated as ``-inf``. + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Args: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). + + .. note:: + This module doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use `LogSoftmax` instead (it's faster and has better numerical properties). + + Examples:: + + >>> m = nn.Softmax(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + + """ + + __constants__ = ["dim"] + dim: int | None + + def __init__(self, dim: int | None = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"dim={self.dim}" + + +class Softmax2d(Module): + r"""Applies SoftMax over features to each spatial location. + + When given an image of ``Channels x Height x Width``, it will + apply `Softmax` to each location :math:`(Channels, h_i, w_j)` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmax2d() + >>> # you softmax over the 2nd dimension + >>> input = torch.randn(2, 3, 12, 13) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + if input.dim() not in (3, 4): + raise ValueError( + f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead" + ) + return F.softmax(input, -3, _stacklevel=5) + + +class LogSoftmax(Module): + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor. + + The LogSoftmax formulation can be simplified as: + + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Args: + dim (int): A dimension along which LogSoftmax will be computed. + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) + + Examples:: + + >>> m = nn.LogSoftmax(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + + __constants__ = ["dim"] + dim: int | None + + def __init__(self, dim: int | None = None) -> None: + super().__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.log_softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"dim={self.dim}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/adaptive.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/adaptive.py new file mode 100644 index 0000000000000000000000000000000000000000..4267ed9993bff1ff69d57028308f4a3121ef2050 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/adaptive.py @@ -0,0 +1,339 @@ +# mypy: allow-untyped-defs + +import itertools +from collections import namedtuple +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import Tensor + +from .container import ModuleList, Sequential +from .linear import Linear +from .module import Module + + +__all__ = ["AdaptiveLogSoftmaxWithLoss"] + +_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"]) + + +class AdaptiveLogSoftmaxWithLoss(Module): + ( + """Efficient softmax approximation. + + As described in + `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, + Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou + `__. +""" + r""" + Adaptive softmax is an approximate strategy for training models with large + output spaces. It is most effective when the label distribution is highly + imbalanced, for example in natural language modelling, where the word + frequency distribution approximately follows the `Zipf's law`_. + + Adaptive softmax partitions the labels into several clusters, according to + their frequency. These clusters may contain different number of targets + each. + Additionally, clusters containing less frequent labels assign lower + dimensional embeddings to those labels, which speeds up the computation. + For each minibatch, only clusters for which at least one target is + present are evaluated. + + The idea is that the clusters which are accessed frequently + (like the first one, containing most frequent labels), should also be cheap + to compute -- that is, contain a small number of assigned labels. + + We highly recommend taking a look at the original paper for more details. + + * :attr:`cutoffs` should be an ordered Sequence of integers sorted + in the increasing order. + It controls number of clusters and the partitioning of targets into + clusters. For example setting ``cutoffs = [10, 100, 1000]`` + means that first `10` targets will be assigned + to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be + assigned to the first cluster, and targets `101, 102, ..., 1000` will be + assigned to the second cluster, while targets + `1001, 1002, ..., n_classes - 1` will be assigned + to the last, third cluster. + + * :attr:`div_value` is used to compute the size of each additional cluster, + which is given as + :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters + for less frequent words having larger indices, + and indices starting from :math:`1`). + + * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the + adaptive softmax. See paper for details. Set to False in the official + implementation. + + .. warning:: + Labels passed as inputs to this module should be sorted according to + their frequency. This means that the most frequent label should be + represented by the index `0`, and the least frequent + label should be represented by the index `n_classes - 1`. + + .. note:: + This module returns a ``NamedTuple`` with ``output`` + and ``loss`` fields. See further documentation for details. + + .. note:: + To compute log-probabilities for all classes, the ``log_prob`` + method can be used. + + Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset + cutoffs (Sequence): Cutoffs used to assign targets to their buckets + div_value (float, optional): value used as an exponent to compute sizes + of the clusters. Default: 4.0 + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the + adaptive softmax. Default: ``False`` + + Returns: + ``NamedTuple`` with ``output`` and ``loss`` fields: + * **output** is a Tensor of size ``N`` containing computed target + log probabilities for each example + * **loss** is a Scalar representing the computed negative + log likelihood loss + + Shape: + - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})` + - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}` + - output1: :math:`(N)` or :math:`()` + - output2: ``Scalar`` + + .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + """ + ) + + in_features: int + n_classes: int + cutoffs: list[int] + div_value: float + head_bias: bool + head: Linear + tail: ModuleList + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4.0, + head_bias: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + cutoffs = list(cutoffs) + + if len(cutoffs) == 0: + raise ValueError("cutoffs should be a sequence of length larger than 0") + + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any(int(c) != c for c in cutoffs) + ): + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear( + self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs + ) + self.tail = ModuleList() + + for i in range(self.n_clusters): + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias=False, **factory_kwargs), + Linear(hsz, osz, bias=False, **factory_kwargs), + ) + + self.tail.append(projection) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + self.head.reset_parameters() + for i2h, h2o in self.tail: # type: ignore[misc] + i2h.reset_parameters() # type: ignore[has-type] + h2o.reset_parameters() # type: ignore[has-type] + + def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + """ + Runs the forward pass. + """ + targ_dim = target_.dim() + + if targ_dim == 1: + if input_.size(0) != target_.size(0): + raise RuntimeError( + "Input and target should have the same size in the batch dimension." + ) + if input_.dim() != 2: + raise RuntimeError( + "1D target tensor expects 2D input tensors, " + "but found inputs with size", + input_.size(), + ) + elif targ_dim == 0: + if input_.dim() != 1: + raise RuntimeError( + "0D target tensor expects 1D input tensors, " + "but found inputs with size", + input_.size(), + ) + else: + raise RuntimeError( + "0D or 1D target tensor expected, multi-target not supported" + ) + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + + used_rows = 0 + batch_size = target.size(0) + + output = input.new_zeros(batch_size) + gather_inds = target.new_empty(batch_size) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx) & (target < high_idx) + row_indices = target_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + gather_inds.index_copy_(0, row_indices, target[target_mask]) + + else: + relative_target = target[target_mask] - low_idx + input_subset = input.index_select(0, row_indices) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + gather_inds.index_fill_(0, row_indices, cluster_index) + cluster_logprob = F.log_softmax(cluster_output, dim=1) + local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) + output.index_copy_(0, row_indices, local_logprob.squeeze(1)) + + used_rows += row_indices.numel() + + if used_rows != batch_size: + raise RuntimeError( + f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. " + ) + + head_output = self.head(input) + head_logprob = F.log_softmax(head_output, dim=1) + output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """Given input tensor, and output of ``self.head``, compute the log of the full distribution.""" + out = input.new_empty((head_output.size(0), self.n_classes)) + head_logprob = F.log_softmax(head_output, dim=1) + + out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(itertools.pairwise(self.cutoffs)): + cluster_output = self.tail[i](input) + cluster_logprob = F.log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + def log_prob(self, input: Tensor) -> Tensor: + r"""Compute log probabilities for all :math:`\texttt{n\_classes}`. + + Args: + input (Tensor): a minibatch of examples + + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + + """ + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + def predict(self, input: Tensor) -> Tensor: + r"""Return the class with the highest probability for each example in the input minibatch. + + This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases. + + Args: + input (Tensor): a minibatch of examples + + Returns: + output (Tensor): a class with the highest probability for each example + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + head_output = self.head(input) + output = torch.argmax(head_output, dim=1) + not_in_shortlist = output >= self.shortlist_size + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return torch.argmax(log_prob, dim=1) + + else: + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) + output[not_in_shortlist] = torch.argmax(log_prob, dim=1) + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..40a912b4f05682792b1a3126b6df53230ced88c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py @@ -0,0 +1,902 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter + +from ._functions import SyncBatchNorm as sync_batch_norm +from .lazy import LazyModuleMixin +from .module import Module + + +__all__ = [ + "BatchNorm1d", + "LazyBatchNorm1d", + "BatchNorm2d", + "LazyBatchNorm2d", + "BatchNorm3d", + "LazyBatchNorm3d", + "SyncBatchNorm", +] + + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm.""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: float | None + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float | None = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer( + "running_mean", torch.zeros(num_features, **factory_kwargs) + ) + self.register_buffer( + "running_var", torch.ones(num_features, **factory_kwargs) + ) + self.running_mean: Tensor | None + self.running_var: Tensor | None + self.register_buffer( + "num_batches_tracked", + torch.tensor( + 0, + dtype=torch.long, + # pyrefly: ignore [bad-argument-type] + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.num_batches_tracked: Tensor | None + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[union-attr] + self.running_var.fill_(1) # type: ignore[union-attr] + self.num_batches_tracked.zero_() # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = ( + self.num_batches_tracked + if self.num_batches_tracked is not None + and self.num_batches_tracked.device != torch.device("meta") + else torch.tensor(0, dtype=torch.long) + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class _BatchNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float | None = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + return F.batch_norm( + input, + # If buffers are not to be tracked, ensure that they won't be updated + ( + self.running_mean + if not self.training or self.track_running_stats + else None + ), + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + + +class _LazyNormBase(LazyModuleMixin, _NormBase): + weight: UninitializedParameter # type: ignore[assignment] + bias: UninitializedParameter # type: ignore[assignment] + + def __init__( + self, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + # affine and track_running_stats are hardcoded to False to + # avoid creating tensors that will soon be overwritten. + 0, + eps, + momentum, + False, + False, + **factory_kwargs, + ) + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + # pyrefly: ignore [bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + if self.track_running_stats: + # pyrefly: ignore [bad-argument-type] + self.running_mean = UninitializedBuffer(**factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.running_var = UninitializedBuffer(**factory_kwargs) + self.num_batches_tracked = torch.tensor( + 0, + dtype=torch.long, + # pyrefly: ignore [bad-argument-type] + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + + def reset_parameters(self) -> None: + # pyrefly: ignore [bad-argument-type] + if not self.has_uninitialized_params() and self.num_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore[override] + # pyrefly: ignore [bad-argument-type] + if self.has_uninitialized_params(): + self.num_features = input.shape[1] + if self.affine: + assert isinstance(self.weight, UninitializedParameter) + assert isinstance(self.bias, UninitializedParameter) + self.weight.materialize((self.num_features,)) + self.bias.materialize((self.num_features,)) + if self.track_running_stats: + self.running_mean.materialize( # type:ignore[union-attr] + (self.num_features,) + ) + self.running_var.materialize( # type:ignore[union-attr] + (self.num_features,) + ) + self.reset_parameters() + + +class BatchNorm1d(_BatchNorm): + r"""Applies Batch Normalization over a 2D or 3D input. + + Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input). By default, the + elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. + At train time in the forward pass, the variance is calculated via the biased estimator, + equivalent to ``torch.var(input, correction=0)``. However, the value stored in the + moving average of the variance is calculated via the unbiased estimator, equivalent to + ``torch.var(input, correction=1)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, + :math:`C` is the number of features or channels, and :math:`L` is the sequence length + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm1d(100, affine=False) + >>> input = torch.randn(20, 100) + >>> output = m(input) + """ + + def _check_input_dim(self, input) -> None: + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +# pyrefly: ignore [inconsistent-inheritance] +class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization. + + Lazy initialization based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm1d # type: ignore[assignment] + + def _check_input_dim(self, input) -> None: + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class BatchNorm2d(_BatchNorm): + r"""Applies Batch Normalization over a 4D input. + + 4D is a mini-batch of 2D inputs + with additional channel dimension. Method described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``torch.var(input, correction=0)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``torch.var(input, correction=1)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm2d(100, affine=False) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _check_input_dim(self, input) -> None: + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + +# pyrefly: ignore [inconsistent-inheritance] +class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization. + + Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm2d # type: ignore[assignment] + + def _check_input_dim(self, input) -> None: + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + +class BatchNorm3d(_BatchNorm): + r"""Applies Batch Normalization over a 5D input. + + 5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set + to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the + standard-deviation is calculated via the biased estimator, equivalent to + ``torch.var(input, correction=0)``. However, the value stored in the moving average of the + standard-deviation is calculated via the unbiased estimator, equivalent to + ``torch.var(input, correction=1)``. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization + or Spatio-temporal Batch Normalization. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _check_input_dim(self, input) -> None: + if input.dim() != 5: + raise ValueError(f"expected 5D input (got {input.dim()}D input)") + + +# pyrefly: ignore [inconsistent-inheritance] +class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): + r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization. + + Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred + from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm3d # type: ignore[assignment] + + def _check_input_dim(self, input) -> None: + if input.dim() != 5: + raise ValueError(f"expected 5D input (got {input.dim()}D input)") + + +class SyncBatchNorm(_BatchNorm): + r"""Applies Batch Normalization over a N-Dimensional input. + + The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over all + mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` + are learnable parameter vectors of size `C` (where `C` is the input size). + By default, the elements of :math:`\gamma` are sampled from + :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done for each channel in the ``C`` dimension, computing + statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch + Normalization or Spatio-temporal Batch Normalization. + + Currently :class:`SyncBatchNorm` only supports + :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use + :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert + :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping + Network with DDP. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, +)` + eps: a value added to the denominator for numerical stability. + Default: ``1e-5`` + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + process_group: synchronization of stats happen within each process group + individually. Default behavior is synchronization across the whole + world + + Shape: + - Input: :math:`(N, C, +)` + - Output: :math:`(N, C, +)` (same shape as input) + + .. note:: + Synchronization of batchnorm statistics occurs only while training, i.e. + synchronization is disabled when ``model.eval()`` is set or if + ``self.training`` is otherwise ``False``. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With Learnable Parameters + >>> m = nn.SyncBatchNorm(100) + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + + >>> # network is nn.BatchNorm layer + >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) + >>> # only single gpu per process is currently supported + >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( + >>> sync_bn_network, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float | None = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group: Any | None = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.process_group = process_group + + def _check_input_dim(self, input) -> None: + if input.dim() < 2: + raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") + + def _check_non_zero_input_channels(self, input) -> None: + if input.size(1) == 0: + raise ValueError( + "SyncBatchNorm number of input channels should be non-zero" + ) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + self._check_input_dim(input) + self._check_non_zero_input_channels(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked.add_(1) + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = ( + self.running_mean if not self.training or self.track_running_stats else None + ) + running_var = ( + self.running_var if not self.training or self.track_running_stats else None + ) + + # Don't sync batchnorm stats in inference mode (model.eval()). + need_sync = ( + bn_training + and self.training + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ) + if need_sync: + # currently only GPU/PrivateUse1 input is supported + if input.device.type not in [ + "cuda", + "hpu", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + raise ValueError( + "SyncBatchNorm expected input tensor to be on GPU or XPU or " + f"{torch._C._get_privateuse1_backend_name()}" + ) + + process_group = torch.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input, + running_mean, + running_var, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + else: + assert bn_training + return sync_batch_norm.apply( + input, + self.weight, + self.bias, + running_mean, + running_var, + self.eps, + exponential_average_factor, + process_group, # type: ignore[possibly-undefined] + world_size, # type: ignore[possibly-undefined] + ) + + @classmethod + def convert_sync_batchnorm(cls, module, process_group=None): + r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers. + + Args: + module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers + process_group (optional): process group to scope synchronization, + default is the whole world + + Returns: + The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm` + layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer, + a new :class:`torch.nn.SyncBatchNorm` layer object will be returned + instead. + + Example:: + + >>> # Network with nn.BatchNorm layer + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> module = torch.nn.Sequential( + >>> torch.nn.Linear(20, 100), + >>> torch.nn.BatchNorm1d(100), + >>> ).cuda() + >>> # creating process group (optional) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> # xdoctest: +SKIP("distributed") + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] + >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) + + """ + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module( + name, cls.convert_sync_batchnorm(child, process_group) + ) + del module + return module_output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/channelshuffle.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/channelshuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..34a48f04f853dd1c458b035635728a122e9cc4d3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/channelshuffle.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["ChannelShuffle"] + + +class ChannelShuffle(Module): + r"""Divides and rearranges the channels in a tensor. + + This operation divides the channels in a tensor of shape :math:`(N, C, *)` + into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them, + while retaining the original tensor shape in the final output. + + Args: + groups (int): number of groups to divide channels in. + + Examples:: + + >>> channel_shuffle = nn.ChannelShuffle(2) + >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2) + >>> input + tensor([[[[ 1., 2.], + [ 3., 4.]], + [[ 5., 6.], + [ 7., 8.]], + [[ 9., 10.], + [11., 12.]], + [[13., 14.], + [15., 16.]]]]) + >>> output = channel_shuffle(input) + >>> output + tensor([[[[ 1., 2.], + [ 3., 4.]], + [[ 9., 10.], + [11., 12.]], + [[ 5., 6.], + [ 7., 8.]], + [[13., 14.], + [15., 16.]]]]) + """ + + __constants__ = ["groups"] + groups: int + + def __init__(self, groups: int) -> None: + super().__init__() + self.groups = groups + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.channel_shuffle(input, self.groups) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"groups={self.groups}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/container.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/container.py new file mode 100644 index 0000000000000000000000000000000000000000..d99151369e18e4d55ef843d6b8c6f4395d6a6453 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/container.py @@ -0,0 +1,1043 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +from collections import abc as container_abcs, OrderedDict +from itertools import chain, islice +from typing import Any, overload, TYPE_CHECKING, TypeVar +from typing_extensions import deprecated, Self + +import torch +from torch._jit_internal import _copy_to_script_wrapper +from torch.nn.parameter import Parameter + +from .module import Module + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + + +__all__ = [ + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", +] + +T = TypeVar("T", bound=Module) +_V = TypeVar("_V") + + +# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList +def _addindent(s_, numSpaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +@deprecated( + "`nn.Container` is deprecated. " + "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.", + category=FutureWarning, +) +class Container(Module): + def __init__(self, **kwargs: Any) -> None: + super().__init__() + for key, value in kwargs.items(): + self.add_module(key, value) + + +class Sequential(Module): + r"""A sequential container. + + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + + Example:: + + # Using Sequential to create a small model. When `model` is run, + # input will first be passed to `Conv2d(1,20,5)`. The output of + # `Conv2d(1,20,5)` will be used as the input to the first + # `ReLU`; the output of the first `ReLU` will become the input + # for `Conv2d(20,64,5)`. Finally, the output of + # `Conv2d(20,64,5)` will be used as input to the second `ReLU` + model = nn.Sequential( + nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() + ) + + # Using Sequential with OrderedDict. This is functionally the + # same as the above code + model = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 20, 5)), + ("relu1", nn.ReLU()), + ("conv2", nn.Conv2d(20, 64, 5)), + ("relu2", nn.ReLU()), + ] + ) + ) + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: ... + + @overload + # pyrefly: ignore [inconsistent-overload] + def __init__(self, arg: OrderedDict[str, Module]) -> None: ... + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V: + """Get the idx-th item of the iterator.""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError(f"index {idx} is out of range") + idx %= size + return next(islice(iterator, idx, None)) + + @_copy_to_script_wrapper + def __getitem__(self, idx: slice | int) -> Sequential | Module: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: slice | int) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict( + zip(str_indices, self._modules.values(), strict=True) + ) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> Sequential: + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def pop(self, key: int | slice) -> Module: + """ + Pop ``key`` from self. + """ + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> Self: + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def __mul__(self, other: int) -> Sequential: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> Sequential: + return self.__mul__(other) + + def __imul__(self, other: int) -> Self: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + @_copy_to_script_wrapper + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + """ + Runs the forward pass. + """ + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> Self: + r"""Append a given module to the end. + + Args: + module (nn.Module): module to append + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> n.append(nn.Linear(3, 4)) + Sequential( + (0): Linear(in_features=1, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=3, bias=True) + (2): Linear(in_features=3, out_features=4, bias=True) + ) + + """ + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> Self: + """ + Inserts a module into the Sequential container at the specified index. + + Args: + index (int): The index to insert the module. + module (Module): The module to be inserted. + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> n.insert(0, nn.Linear(3, 4)) + Sequential( + (0): Linear(in_features=3, out_features=4, bias=True) + (1): Linear(in_features=1, out_features=2, bias=True) + (2): Linear(in_features=2, out_features=3, bias=True) + ) + + """ + if not isinstance(module, Module): + raise AssertionError(f"module should be of type: {Module}") + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError(f"Index out of range: {index}") + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential: Iterable[Module]) -> Self: + """ + Extends the current Sequential container with layers from another Sequential container. + + Args: + sequential (Sequential): A Sequential container whose layers will be added to the current container. + + Example:: + + >>> import torch.nn as nn + >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5)) + >>> n.extend(other) # or `n + other` + Sequential( + (0): Linear(in_features=1, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=3, bias=True) + (2): Linear(in_features=3, out_features=4, bias=True) + (3): Linear(in_features=4, out_features=5, bias=True) + ) + + """ + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Iterable[Module] | None = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: slice) -> ModuleList: ... + + @overload + def __getitem__(self, idx: int) -> Module: ... + + @_copy_to_script_wrapper + def __getitem__(self, idx: int | slice) -> Module | ModuleList: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: int | slice) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict( + zip(str_indices, self._modules.values(), strict=True) + ) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> Self: + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> ModuleList: + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self) -> str: + """Return a custom repr for ModuleList that compresses repeated module representations.""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip( + start_end_indices, repeated_blocks, strict=True + ): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + @_copy_to_script_wrapper + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> Self: + r"""Append a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def pop(self, key: int | slice) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> Self: + r"""Append modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward altogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~torch.nn.ModuleDict` (the argument to + :meth:`~torch.nn.ModuleDict.update`). + + Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping + types does not preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.choices = nn.ModuleDict( + {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)} + ) + self.activations = nn.ModuleDict( + [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]] + ) + + def forward(self, x, choice, act): + x = self.choices[choice](x) + x = self.activations[act](x) + return x + """ + + _modules: dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Mapping[str, Module] | None = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + @_copy_to_script_wrapper + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + @_copy_to_script_wrapper + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + @_copy_to_script_wrapper + def keys(self) -> container_abcs.KeysView[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + @_copy_to_script_wrapper + def items(self) -> container_abcs.ItemsView[str, Module]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + @_copy_to_script_wrapper + def values(self) -> container_abcs.ValuesView[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + # pyrefly: ignore [bad-argument-type] + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + # pyrefly: ignore [bad-argument-type] + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward altogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + r"""Holds parameters in a list. + + :class:`~torch.nn.ParameterList` can be used like a regular Python + list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered, + and will be visible by all :class:`~torch.nn.Module` methods. + + Note that the constructor, assigning an element of the list, the + :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend` + method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`. + + Args: + parameters (iterable, optional): an iterable of elements to add to the list. + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.params = nn.ParameterList( + [nn.Parameter(torch.randn(10, 10)) for i in range(10)] + ) + + def forward(self, x): + # ParameterList can act as an iterable, or be indexed using ints + for i, p in enumerate(self.params): + x = self.params[i // 2].mm(x) + p.mm(x) + return x + """ + + def __init__(self, values: Iterable[Any] | None = None) -> None: + super().__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: ... + + @overload + # pyrefly: ignore [inconsistent-overload] + def __getitem__(self: T, idx: slice) -> T: ... + + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. + idx = self._get_abs_string_index(idx) + if isinstance(param, torch.Tensor) and not isinstance(param, Parameter): + param = Parameter(param) + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> Self: + return self.extend(parameters) + + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> Self: + """Append a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> Self: + """Append values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance( + values, torch.Tensor + ): + raise TypeError( + "ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__ + ) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, + size_str, + device_str, + ) + # pyrefly: ignore [bad-argument-type] + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + # pyrefly: ignore [bad-argument-type] + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError("ParameterList should not be called.") + + +class ParameterDict(Module): + r"""Holds parameters in a dictionary. + + ParameterDict can be indexed like a regular Python dictionary, but Parameters it + contains are properly registered, and will be visible by all Module methods. + Other objects are treated as would be done by a regular Python dictionary + + :class:`~torch.nn.ParameterDict` is an **ordered** dictionary. + :meth:`~torch.nn.ParameterDict.update` with other unordered mapping + types (e.g., Python's plain ``dict``) does not preserve the order of the + merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict` + will preserve their ordering. + + Note that the constructor, assigning an element of the dictionary and the + :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into + :class:`~torch.nn.Parameter`. + + Args: + values (iterable, optional): a mapping (dictionary) of + (string : Any) or an iterable of key-value pairs + of type (string, Any) + + Example:: + + class MyModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.params = nn.ParameterDict( + { + "left": nn.Parameter(torch.randn(5, 10)), + "right": nn.Parameter(torch.randn(5, 10)), + } + ) + + def forward(self, x, choice): + x = self.params[choice].mm(x) + return x + """ + + def __init__(self, parameters: Any = None) -> None: + super().__init__() + self._keys: dict[str, None] = {} + if parameters is not None: + self.update(parameters) + + def _key_to_attr(self, key: str) -> str: + if not isinstance(key, str): + raise TypeError( + "Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys." + ) + else: + # Use the key as-is so that `.named_parameters()` returns the right thing + return key + + def __getitem__(self, key: str) -> Any: + attr = self._key_to_attr(key) + return getattr(self, attr) + + def __setitem__(self, key: str, value: Any) -> None: + # Note that all other function that add an entry to the dictionary part of + # the ParameterDict end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the dictionary part and thus won't + # call into this function. + self._keys[key] = None + attr = self._key_to_attr(key) + if isinstance(value, torch.Tensor) and not isinstance(value, Parameter): + value = Parameter(value) + setattr(self, attr, value) + + def __delitem__(self, key: str) -> None: + del self._keys[key] + attr = self._key_to_attr(key) + delattr(self, attr) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __reversed__(self) -> Iterator[str]: + return reversed(self._keys) + + def copy(self) -> ParameterDict: + """Return a copy of this :class:`~torch.nn.ParameterDict` instance.""" + # We have to use an OrderedDict because the ParameterDict constructor + # behaves differently on plain dict vs OrderedDict + return ParameterDict(OrderedDict((k, self[k]) for k in self._keys)) + + def __contains__(self, key: str) -> bool: + return key in self._keys + + def setdefault(self, key: str, default: Any | None = None) -> Any: + """Set the default for a key in the Parameterdict. + + If key is in the ParameterDict, return its value. + If not, insert `key` with a parameter `default` and return `default`. + `default` defaults to `None`. + + Args: + key (str): key to set default for + default (Any): the parameter set to the key + """ + if key not in self: + self[key] = default + return self[key] + + def clear(self) -> None: + """Remove all items from the ParameterDict.""" + for k in self._keys.copy(): + del self[k] + + def pop(self, key: str) -> Any: + r"""Remove key from the ParameterDict and return its parameter. + + Args: + key (str): key to pop from the ParameterDict + """ + v = self[key] + del self[key] + return v + + def popitem(self) -> tuple[str, Any]: + """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict.""" + k, _ = self._keys.popitem() + # We need the key in the _keys to be able to access/del + self._keys[k] = None + val = self[k] + del self[k] + return k, val + + def get(self, key: str, default: Any | None = None) -> Any: + r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. + + Args: + key (str): key to get from the ParameterDict + default (Parameter, optional): value to return if key not present + """ + return self[key] if key in self else default # noqa: SIM401 + + def fromkeys( + self, keys: Iterable[str], default: Any | None = None + ) -> ParameterDict: + r"""Return a new ParameterDict with the keys provided. + + Args: + keys (iterable, string): keys to make the new ParameterDict from + default (Parameter, optional): value to set for all keys + """ + return ParameterDict((k, default) for k in keys) + + def keys(self) -> container_abcs.KeysView[str]: + r"""Return an iterable of the ParameterDict keys.""" + return self._keys.keys() + + def items(self) -> Iterable[tuple[str, Any]]: + r"""Return an iterable of the ParameterDict key/value pairs.""" + return ((k, self[k]) for k in self._keys) + + def values(self) -> Iterable[Any]: + r"""Return an iterable of the ParameterDict values.""" + return (self[k] for k in self._keys) + + def update(self, parameters: Mapping[str, Any] | ParameterDict) -> None: + r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys. + + .. note:: + If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + parameters (iterable): a mapping (dictionary) from string to + :class:`~torch.nn.Parameter`, or an iterable of + key-value pairs of type (string, :class:`~torch.nn.Parameter`) + """ + if not isinstance(parameters, container_abcs.Iterable): + raise TypeError( + "ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(parameters).__name__ + ) + + if isinstance(parameters, (OrderedDict, ParameterDict)): + for key, parameter in parameters.items(): + self[key] = parameter + elif isinstance(parameters, container_abcs.Mapping): + for key, parameter in sorted(parameters.items()): + self[key] = parameter + else: + for j, p in enumerate(parameters): + if not isinstance(p, container_abcs.Iterable): + raise TypeError( + "ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + # pyrefly: ignore [bad-argument-type] + if not len(p) == 2: + raise ValueError( + "ParameterDict update sequence element " + # pyrefly: ignore [bad-argument-type] + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment + self[p[0]] = p[1] # type: ignore[assignment] + + def extra_repr(self) -> str: + child_lines = [] + for k, p in self.items(): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + torch.typename(p), + size_str, + device_str, + ) + # pyrefly: ignore [bad-argument-type] + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + # pyrefly: ignore [bad-argument-type] + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("ParameterDict should not be called.") + + def __or__(self, other: ParameterDict) -> ParameterDict: + copy = self.copy() + copy.update(other) + return copy + + def __ror__(self, other: ParameterDict) -> ParameterDict: + copy = other.copy() + copy.update(self) + return copy + + def __ior__(self, other: ParameterDict) -> Self: + self.update(other) + return self diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..8b74b6a5a39e8ebfec821a047936e82b3cf002f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/conv.py @@ -0,0 +1,1904 @@ +# mypy: allow-untyped-defs +import math +from typing import Literal, Optional +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch._torch_docs import reproducibility_notes +from torch.nn import functional as F, init +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from torch.nn.parameter import Parameter, UninitializedParameter + +from .lazy import LazyModuleMixin +from .module import Module +from .utils import _pair, _reverse_repeat_tuple, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", +] + +convolution_notes = { + "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""", +} # noqa: B950 + + +class _ConvNd(Module): + __constants__ = [ + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + "output_padding", + "in_channels", + "out_channels", + "kernel_size", + ] + __annotations__ = {"bias": Optional[torch.Tensor]} + + def _conv_forward( # type: ignore[empty-body] + self, input: Tensor, weight: Tensor, bias: Tensor | None + ) -> Tensor: ... + + in_channels: int + _reversed_padding_repeated_twice: list[int] + out_channels: int + kernel_size: tuple[int, ...] + stride: tuple[int, ...] + padding: str | tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] + groups: int + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] + weight: Tensor + bias: Tensor | None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: str | tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"], + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if groups <= 0: + raise ValueError("groups must be a positive integer") + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) + + valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} + if padding_mode not in valid_padding_modes: + raise ValueError( + f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + # `_reversed_padding_repeated_twice` is the padding to be passed to + # `F.pad` if needed (e.g., for non-zero padding types that are + # implemented as two ops: padding + conv). `F.pad` accepts paddings in + # reverse order than the dimension. + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == "same": + for d, k, i in zip( + dilation, + kernel_size, + range(len(kernel_size) - 1, -1, -1), + strict=False, + ): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad + ) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple( + self.padding, 2 + ) + + if transposed: + self.weight = Parameter( + torch.empty( + (in_channels, out_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) + else: + self.weight = Parameter( + torch.empty( + (out_channels, in_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) + if bias: + self.bias = Parameter(torch.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + return s.format(**self.__dict__) + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "padding_mode"): + self.padding_mode = "zeros" + + +class Conv1d(_ConvNd): + __doc__ = ( + r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + Note: + {depthwise_separable_note} + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + + Examples:: + + >>> m = nn.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: str | _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _single(0), + self.dilation, + self.groups, + ) + + return F.conv1d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv2d(_ConvNd): + __doc__ = ( + r"""Applies a 2D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` + can be precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) + + + where :math:`\star` is the valid 2D `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`H` is a height of input planes in pixels, and :math:`W` is + width in pixels. + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. +""" + r""" + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples: + + >>> # With square kernels and equal stride + >>> m = nn.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: str | _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + + return F.conv2d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv3d(_ConvNd): + __doc__ = ( + r"""Applies a 3D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` + and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: + + .. math:: + out(N_i, C_{out_j}) = bias(C_{out_j}) + + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) + + where :math:`\star` is the valid 3D `cross-correlation`_ operator + """ + + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all six sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, + where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: str | _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + if self.padding_mode != "zeros": + return F.conv3d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _triple(0), + self.dilation, + self.groups, + ) + + return F.conv3d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class _ConvTransposeNd(_ConvNd): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ) -> None: + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + # dilation being an optional parameter is for backwards + # compatibility + def _output_padding( + self, + input: Tensor, + output_size: list[int] | None, + stride: list[int], + padding: list[int], + kernel_size: list[int], + num_spatial_dims: int, + dilation: list[int] | None = None, + ) -> list[int]: + if output_size is None: + ret = _single(self.output_padding) # converting to list if was not already + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} " + f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" + ) + + min_sizes = torch.jit.annotate(list[int], []) + max_sizes = torch.jit.annotate(list[int], []) + for d in range(num_spatial_dims): + dim_size = ( + (input.size(d + num_non_spatial_dims) - 1) * stride[d] + - 2 * padding[d] + + (dilation[d] if dilation is not None else 1) + * (kernel_size[d] - 1) + + 1 + ) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError( + f"requested an output size of {output_size}, but valid sizes range " + f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})" + ) + + res = torch.jit.annotate(list[int], []) + for d in range(num_spatial_dims): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret + + +class ConvTranspose1d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv1d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} + \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape (out_channels). + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12) + >>> downsample = nn.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nn.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose1d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 1 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + return F.conv_transpose1d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose2d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv2d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. When stride > 1, ConvTranspose2d inserts zeros between input + elements along the spatial dimensions before applying the convolution kernel. This zero-insertion operation is the standard + behavior of transposed convolutions, which can increase the spatial resolution and is equivalent to a learnable + upsampling operation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimensions + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_2_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + """ + Performs the forward pass. + + Attributes: + input (Tensor): The input tensor. + output_size (list[int], optional): A list of integers representing + the size of the output tensor. Default is None. + """ + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose2d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 2 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + + return F.conv_transpose2d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose3d(_ConvTransposeNd): + __doc__ = ( + r"""Applies a 3D transposed convolution operator over an input image composed of several input + planes. + The transposed convolution operator multiplies each input value element-wise by a learnable kernel, + and sums over the outputs from all input feature planes. + + This module can be seen as the gradient of Conv3d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or + :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1 + + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + ) + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose3d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 3 + output_padding = self._output_padding( + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) + + return F.conv_transpose3d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`. +# +# `_ConvTransposeMixin` was a mixin that was removed. It is meant to be used +# with `_ConvNd` to construct actual module classes that implements conv +# transpose ops: +# +# class MyConvTranspose(_ConvNd, _ConvTransposeMixin): +# ... +# +# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper +# subclass of `_ConvNd`. However, some user code in the wild still (incorrectly) +# use the internal class `_ConvTransposeMixin`. Hence, we provide this alias +# for BC, because it is cheap and easy for us to do so, even though that +# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as +# above would still work). +class _ConvTransposeMixin(_ConvTransposeNd): + @deprecated( + "`_ConvTransposeMixin` is a deprecated internal class. " + "Please consider using public APIs.", + category=FutureWarning, + ) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +# TODO: Conv2dLocal +# TODO: Conv2dMap +# TODO: ConvTranspose2dMap + + +class _LazyConvXdMixin(LazyModuleMixin): + groups: int + transposed: bool + in_channels: int + out_channels: int + kernel_size: tuple[int, ...] + weight: UninitializedParameter + bias: UninitializedParameter + + def reset_parameters(self) -> None: + # has_uninitialized_params is defined in parent class and it is using a protocol on self + if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc] + # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined + # in super class. Turns out that it is defined in _ConvND which is inherited by any class + # that also inherits _LazyConvXdMixin + super().reset_parameters() # type: ignore[misc] + + # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin + def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None: # type: ignore[override] + # defined by parent class but using a protocol + if self.has_uninitialized_params(): # type: ignore[misc] + self.in_channels = self._get_in_channels(input) + if self.in_channels % self.groups != 0: + raise ValueError("in_channels must be divisible by groups") + assert isinstance(self.weight, UninitializedParameter) + if self.transposed: + self.weight.materialize( + ( + self.in_channels, + self.out_channels // self.groups, + *self.kernel_size, + ) + ) + else: + self.weight.materialize( + ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) + ) + if self.bias is not None: + assert isinstance(self.bias, UninitializedParameter) + self.bias.materialize((self.out_channels,)) + self.reset_parameters() + + # Function to extract in_channels from first input. + def _get_in_channels(self, input: Tensor) -> int: + num_spatial_dims = self._get_num_spatial_dims() + num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim + num_dims_batch = num_dims_no_batch + 1 + if input.dim() not in (num_dims_no_batch, num_dims_batch): + raise RuntimeError( + f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input " + f"to {self.__class__.__name__}, but " + f"got input of size: {input.shape}" + ) + return input.shape[1] if input.dim() == num_dims_batch else input.shape[0] + + # Function to return the number of spatial dims expected for inputs to the module. + # This is expected to be implemented by subclasses. + def _get_num_spatial_dims(self) -> int: + raise NotImplementedError + + +# LazyConv1d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConv2d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv2d` that is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConv3d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv3d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + + .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 + + +# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose1d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UninitializeParameter +class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore [bad-argument-type] + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs, + ) + # pyrefly: ignore [bad-override, bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + # pyrefly: ignore [bad-override, bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/distance.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..27ab92fef5eb4a6da80d97d8559a204a6956ac4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/distance.py @@ -0,0 +1,100 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["PairwiseDistance", "CosineSimilarity"] + + +class PairwiseDistance(Module): + r""" + Computes the pairwise distance between input vectors, or between columns of input matrices. + + Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero + if ``p`` is negative, i.e.: + + .. math :: + \mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p, + + where :math:`e` is the vector of ones and the ``p``-norm is given by. + + .. math :: + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Args: + p (real, optional): the norm degree. Can be negative. Default: 2 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-6 + keepdim (bool, optional): Determines whether or not to keep the vector dimension. + Default: False + Shape: + - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension` + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1 + - Output: :math:`(N)` or :math:`()` based on input dimension. + If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension. + + Examples: + >>> pdist = nn.PairwiseDistance(p=2) + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = pdist(input1, input2) + """ + + __constants__ = ["norm", "eps", "keepdim"] + norm: float + eps: float + keepdim: bool + + def __init__( + self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False + ) -> None: + super().__init__() + self.norm = p + self.eps = eps + self.keepdim = keepdim + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) + + +class CosineSimilarity(Module): + r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`. + + .. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. + + Args: + dim (int, optional): Dimension where cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + Shape: + - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim` + - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`, + and broadcastable with x1 at other dimensions. + - Output: :math:`(\ast_1, \ast_2)` + + Examples: + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) + >>> output = cos(input1, input2) + """ + + __constants__ = ["dim", "eps"] + dim: int + eps: float + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/dropout.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3de5d61dc0b56d6f708a242611bfc5b2850288 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/dropout.py @@ -0,0 +1,323 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = [ + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] + + +class _DropoutNd(Module): + __constants__ = ["p", "inplace"] + p: float + inplace: bool + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super().__init__() + if p < 0 or p > 1: + raise ValueError( + f"dropout probability has to be between 0 and 1, but got {p}" + ) + self.p = p + self.inplace = inplace + + def extra_repr(self) -> str: + return f"p={self.p}, inplace={self.inplace}" + + +class Dropout(_DropoutNd): + r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`. + + The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution. + + Each channel will be zeroed out independently on every forward call. + + This has proven to be an effective technique for regularization and + preventing the co-adaptation of neurons as described in the paper + `Improving neural networks by preventing co-adaptation of feature + detectors`_ . + + Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during + training. This means that during evaluation the module simply computes an + identity function. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.Dropout(p=0.2) + >>> input = torch.randn(20, 16) + >>> output = m(input) + + .. _Improving neural networks by preventing co-adaptation of feature + detectors: https://arxiv.org/abs/1207.0580 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.dropout(input, self.p, self.training, self.inplace) + + +class Dropout1d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 1D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv1d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout1d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)`. + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout1d(p=0.2) + >>> input = torch.randn(20, 16, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.dropout1d(input, self.p, self.training, self.inplace) + + +class Dropout2d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 2D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv2d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout2d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zero-ed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + .. warning :: + Due to historical reasons, this class will perform 1D channel-wise dropout + for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT + support inputs without a batch dimension of shape :math:`(C, H, W)`. This + behavior will change in a future release to interpret 3D inputs as no-batch-dim + inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`. + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`. + - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout2d(p=0.2) + >>> input = torch.randn(20, 16, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.dropout2d(input, self.p, self.training, self.inplace) + + +class Dropout3d(_DropoutNd): + r"""Randomly zero out entire channels. + + A channel is a 3D feature map, + e.g., the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]`. + + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + Usually the input comes from :class:`nn.Conv3d` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.Dropout3d` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.Dropout3d(p=0.2) + >>> input = torch.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.dropout3d(input, self.p, self.training, self.inplace) + + +class AlphaDropout(_DropoutNd): + r"""Applies Alpha Dropout over the input. + + Alpha Dropout is a type of Dropout that maintains the self-normalizing + property. + For an input with zero mean and unit standard deviation, the output of + Alpha Dropout maintains the original mean and standard deviation of the + input. + Alpha Dropout goes hand-in-hand with SELU activation function, which ensures + that the outputs have zero mean and unit standard deviation. + + During training, it randomly masks some of the elements of the input + tensor with probability *p* using samples from a bernoulli distribution. + The elements to masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit standard deviation. + + During evaluation the module simply computes an identity function. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + p (float): probability of an element to be dropped. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + Examples:: + + >>> m = nn.AlphaDropout(p=0.2) + >>> input = torch.randn(20, 16) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.alpha_dropout(input, self.p, self.training) + + +class FeatureAlphaDropout(_DropoutNd): + r"""Randomly masks out entire channels. + + A channel is a feature map, + e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. More details + can be found in the paper `Self-Normalizing Neural Networks`_ . + + Each element will be masked independently for each sample on every forward + call with probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + Usually the input comes from :class:`nn.AlphaDropout` modules. + + As described in the paper + `Efficient Object Localization Using Convolutional Networks`_ , + if adjacent pixels within feature maps are strongly correlated + (as is normally the case in early convolution layers) then i.i.d. dropout + will not regularize the activations and will otherwise just result + in an effective learning rate decrease. + + In this case, :func:`nn.AlphaDropout` will help promote independence between + feature maps and should be used instead. + + Args: + p (float, optional): probability of an element to be zeroed. Default: 0.5 + inplace (bool, optional): If set to ``True``, will do this operation + in-place + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`. + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input). + + Examples:: + + >>> m = nn.FeatureAlphaDropout(p=0.2) + >>> input = torch.randn(20, 16, 4, 32, 32) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + .. _Efficient Object Localization Using Convolutional Networks: + https://arxiv.org/abs/1411.4280 + """ + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/flatten.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..146a1890d422475712c9d62d0ff841530282d30e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/flatten.py @@ -0,0 +1,167 @@ +# mypy: allow-untyped-defs + +from torch import Tensor +from torch.types import _size + +from .module import Module + + +__all__ = ["Flatten", "Unflatten"] + + +class Flatten(Module): + r""" + Flattens a contiguous range of dims into a tensor. + + For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details. + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + + Examples:: + >>> input = torch.randn(32, 1, 5, 5) + >>> # With default parameters + >>> m = nn.Flatten() + >>> output = m(input) + >>> output.size() + torch.Size([32, 25]) + >>> # With non-default parameters + >>> m = nn.Flatten(0, 2) + >>> output = m(input) + >>> output.size() + torch.Size([160, 5]) + """ + + __constants__ = ["start_dim", "end_dim"] + start_dim: int + end_dim: int + + def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return input.flatten(self.start_dim, self.end_dim) + + def extra_repr(self) -> str: + """ + Returns the extra representation of the module. + """ + return f"start_dim={self.start_dim}, end_dim={self.end_dim}" + + +class Unflatten(Module): + r""" + Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. + + * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can + be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. + + * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. + + Shape: + - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at + dimension :attr:`dim` and :math:`*` means any number of dimensions including none. + - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and + :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. + + Args: + dim (Union[int, str]): Dimension to be unflattened + unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension + + Examples: + >>> input = torch.randn(2, 50) + >>> # With tuple of ints + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, (2, 5, 5)) + >>> ) + >>> output = m(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + >>> # With torch.Size + >>> m = nn.Sequential( + >>> nn.Linear(50, 50), + >>> nn.Unflatten(1, torch.Size([2, 5, 5])) + >>> ) + >>> output = m(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + >>> # With namedshape (tuple of tuples) + >>> input = torch.randn(2, 50, names=("N", "features")) + >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5))) + >>> output = unflatten(input) + >>> output.size() + torch.Size([2, 2, 5, 5]) + """ + + NamedShape = tuple[tuple[str, int]] + + __constants__ = ["dim", "unflattened_size"] + dim: int | str + unflattened_size: _size | NamedShape + + def __init__(self, dim: int | str, unflattened_size: _size | NamedShape) -> None: + super().__init__() + + if isinstance(dim, int): + self._require_tuple_int(unflattened_size) + elif isinstance(dim, str): + self._require_tuple_tuple(unflattened_size) + else: + raise TypeError("invalid argument type for dim parameter") + + self.dim = dim + self.unflattened_size = unflattened_size + + def _require_tuple_tuple(self, input) -> None: + if isinstance(input, tuple): + for idx, elem in enumerate(input): + if not isinstance(elem, tuple): + raise TypeError( + "unflattened_size must be tuple of tuples, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) + return + raise TypeError( + "unflattened_size must be a tuple of tuples, " + + f"but found type {type(input).__name__}" + ) + + def _require_tuple_int(self, input) -> None: + if isinstance(input, (tuple, list)): + for idx, elem in enumerate(input): + if not isinstance(elem, int): + raise TypeError( + "unflattened_size must be tuple of ints, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) + return + raise TypeError( + f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}" + ) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return input.unflatten(self.dim, self.unflattened_size) + + def extra_repr(self) -> str: + """ + Returns the extra representation of the module. + """ + return f"dim={self.dim}, unflattened_size={self.unflattened_size}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/fold.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/fold.py new file mode 100644 index 0000000000000000000000000000000000000000..ab1a58882c852370141e1e1dd911278334b425d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/fold.py @@ -0,0 +1,335 @@ +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _size_any_t + +from .module import Module + + +__all__ = ["Fold", "Unfold"] + + +class Fold(Module): + ( + r"""Combines an array of sliding local blocks into a large containing tensor. + + Consider a batched :attr:`input` tensor containing sliding local blocks, + e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, + where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})` + is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})` + spatial locations each containing a :math:`C`-channeled vector), and + :math:`L` is the total number of blocks. (This is exactly the + same specification as the output shape of :class:`~torch.nn.Unfold`.) This + operation combines these local blocks into the large :attr:`output` tensor + of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the + arguments must satisfy + + .. math:: + L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`d` is over all spatial dimensions. + + * :attr:`output_size` describes the spatial shape of the large containing + tensor of the sliding local blocks. It is useful to resolve the ambiguity + when multiple input shapes map to same number of sliding blocks, e.g., + with ``stride > 0``. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + Args: + output_size (int or tuple): the shape of the spatial dimensions of the + output (i.e., ``output.sizes()[2:]``) + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, + :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then + their values will be replicated across all spatial dimensions. + + * For the case of two output spatial dimensions this operation is sometimes + called ``col2im``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + Shape: + - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)` + - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above + + Examples:: + + >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2)) + >>> input = torch.randn(1, 3 * 2 * 2, 12) + >>> output = fold(input) + >>> output.size() + torch.Size([1, 3, 4, 5]) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + ) + + __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"] + output_size: _size_any_t + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super().__init__() + self.output_size = output_size + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.fold( + input, + self.output_size, + self.kernel_size, + self.dilation, + self.padding, + self.stride, + ) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return ( + "output_size={output_size}, kernel_size={kernel_size}, " + "dilation={dilation}, padding={padding}, stride={stride}".format( + **self.__dict__ + ) + ) + + +class Unfold(Module): + ( + r"""Extracts sliding local blocks from a batched input tensor. + + Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, + where :math:`N` is the batch dimension, :math:`C` is the channel dimension, + and :math:`*` represent arbitrary spatial dimensions. This operation flattens + each sliding :attr:`kernel_size`-sized block within the spatial dimensions + of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output` + tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where + :math:`C \times \prod(\text{kernel\_size})` is the total number of values + within each block (a block has :math:`\prod(\text{kernel\_size})` spatial + locations each containing a :math:`C`-channeled vector), and :math:`L` is + the total number of such blocks: + + .. math:: + L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`\text{spatial\_size}` is formed by the spatial dimensions + of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial + dimensions. + + Therefore, indexing :attr:`output` at the last dimension (column dimension) + gives all values within a certain block. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. +""" + """ + * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" + Args: + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple, optional): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or + :attr:`stride` is an int or a tuple of length 1, their values will be + replicated across all spatial dimensions. + + * For the case of two input spatial dimensions this operation is sometimes + called ``im2col``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above + + Examples:: + + >>> unfold = nn.Unfold(kernel_size=(2, 3)) + >>> input = torch.randn(2, 5, 3, 4) + >>> output = unfold(input) + >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) + >>> # 4 blocks (2x3 kernels) in total in the 3x4 input + >>> output.size() + torch.Size([2, 30, 4]) + + >>> # xdoctest: +IGNORE_WANT + >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) + >>> inp = torch.randn(1, 3, 10, 12) + >>> w = torch.randn(2, 3, 4, 5) + >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) + >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) + >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) + >>> # or equivalently (and avoiding a copy), + >>> # out = out_unf.view(1, 2, 7, 8) + >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() + tensor(1.9073e-06) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + ) + + __constants__ = ["kernel_size", "dilation", "padding", "stride"] + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.unfold( + input, self.kernel_size, self.dilation, self.padding, self.stride + ) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return ( + "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," + " stride={stride}".format(**self.__dict__) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/instancenorm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/instancenorm.py new file mode 100644 index 0000000000000000000000000000000000000000..058ffb3ed9aa9fa9bf496c709b4f3e6c48e72178 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/instancenorm.py @@ -0,0 +1,472 @@ +# mypy: allow-untyped-defs + +import warnings + +import torch.nn.functional as F +from torch import Tensor + +from .batchnorm import _LazyNormBase, _NormBase + + +__all__ = [ + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", +] + + +class _InstanceNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def _check_input_dim(self, input): + raise NotImplementedError + + def _get_no_batch_dim(self): + raise NotImplementedError + + def _handle_no_batch_input(self, input): + return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0) + + def _apply_instance_norm(self, input): + return F.instance_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum if self.momentum is not None else 0.0, + self.eps, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + version = local_metadata.get("version", None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ("running_mean", "running_var"): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. See " + "the documentation of {klass} for details.".format( + names=" and ".join(f'"{k}"' for k in running_stats_keys), + klass=self.__class__.__name__, + ) + ) + for key in running_stats_keys: + state_dict.pop(key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + feature_dim = input.dim() - self._get_no_batch_dim() + if input.size(feature_dim) != self.num_features: + if self.affine: + raise ValueError( + f"expected input's size at dim={feature_dim} to match num_features" + f" ({self.num_features}), but got: {input.size(feature_dim)}." + ) + else: + warnings.warn( + f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False", + stacklevel=2, + ) + + if input.dim() == self._get_no_batch_dim(): + return self._handle_no_batch_input(input) + + return self._apply_instance_norm(input) + + +class InstanceNorm1d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 2D (unbatched) or 3D (batched) input as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm1d` is applied + on each channel of channeled data like multidimensional time series, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm1d` usually don't apply affine + transform. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm1d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm1d(100, affine=True) + >>> input = torch.randn(20, 100, 40) + >>> output = m(input) + """ + + def _get_no_batch_dim(self) -> int: + return 2 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (2, 3): + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, L)` or :math:`(C, L)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + """ + + cls_to_become = InstanceNorm1d # type: ignore[assignment] + + def _get_no_batch_dim(self) -> int: + return 2 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (2, 3): + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class InstanceNorm2d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 4D input (a mini-batch of 2D inputs + with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm2d` is applied + on each channel of channeled data like RGB images, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm2d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm2d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm2d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _get_no_batch_dim(self) -> int: + return 3 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm2d # type: ignore[assignment] + + def _get_no_batch_dim(self) -> int: + return 3 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class InstanceNorm3d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size C (where C is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm3d` is applied + on each channel of channeled data like 3D models with RGB color, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm3d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm3d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm3d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _get_no_batch_dim(self) -> int: + return 4 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (4, 5): + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") + + +class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm3d # type: ignore[assignment] + + def _get_no_batch_dim(self) -> int: + return 4 + + def _check_input_dim(self, input) -> None: + if input.dim() not in (4, 5): + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/lazy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..72d90d1c10364ea380b1f27069dd69dda6ec80cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/lazy.py @@ -0,0 +1,278 @@ +# mypy: allow-untyped-defs +import itertools +from typing import Any, Protocol + +import torch +from torch.nn.parameter import is_lazy + + +__all__ = ["LazyModuleMixin"] + + +class _LazyProtocol(Protocol): + """This class is used to avoid errors with mypy checks for the attributes in a mixin. + + https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes + """ + + def _register_load_state_dict_pre_hook(self, hook): ... + + def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ... + + def _lazy_load_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): ... + + def _get_name(self): ... + + def _infer_parameters(self, module, input): ... + + @property + def _parameters(self): ... + + @property + def _buffers(self): ... + + @property + def _non_persistent_buffers_set(self): ... + + @property + def _load_hook(self): ... + + @property + def _initialize_hook(self): ... + + +class LazyModuleMixin: + r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules". + + .. warning: + Lazy modules are an experimental new feature under active development, + and their API is likely to change. + + Modules that lazily initialize parameters, or "lazy modules", + derive the shapes of their parameters from the first input(s) + to their forward method. Until that first forward they contain + :class:`torch.nn.UninitializedParameter` s that should not be accessed + or used, and afterward they contain regular :class:`torch.nn.Parameter` s. + Lazy modules are convenient since they don't require computing some + module arguments, like the :attr:`in_features` argument of a + typical :class:`torch.nn.Linear`. + + After construction, networks with lazy modules should first + be converted to the desired dtype and placed on the expected device. + This is because lazy modules only perform shape inference so the usual dtype + and device placement behavior applies. + The lazy modules should then perform "dry runs" to initialize all the components in the module. + These "dry runs" send inputs of the correct size, dtype, and device through + the network and to each one of its lazy modules. After this the network can be used as usual. + + >>> # xdoctest: +SKIP + >>> class LazyMLP(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.fc1 = torch.nn.LazyLinear(10) + ... self.relu1 = torch.nn.ReLU() + ... self.fc2 = torch.nn.LazyLinear(1) + ... self.relu2 = torch.nn.ReLU() + ... + ... def forward(self, input): + ... x = self.relu1(self.fc1(input)) + ... y = self.relu2(self.fc2(x)) + ... return y + >>> # constructs a network with lazy modules + >>> lazy_mlp = LazyMLP() + >>> # transforms the network's device and dtype + >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' + >>> lazy_mlp = lazy_mlp.cuda() + >>> lazy_mlp + LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) + (relu1): ReLU() + (fc2): LazyLinear(in_features=0, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # performs a dry run to initialize the network's lazy modules + >>> lazy_mlp(torch.ones(10, 10).cuda()) + >>> # after initialization, LazyLinear modules become regular Linear modules + >>> lazy_mlp + LazyMLP( + (fc1): Linear(in_features=10, out_features=10, bias=True) + (relu1): ReLU() + (fc2): Linear(in_features=10, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # attaches an optimizer, since parameters can now be used as usual + >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01) + + A final caveat when using lazy modules is that the order of initialization of a network's + parameters may change, since the lazy modules are always initialized after other modules. + For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module + first and then a regular :class:`torch.nn.Linear` second, the second module would be + initialized on construction and the first module would be initialized during the first dry run. + This can cause the parameters of a network using lazy modules to be initialized differently + than the parameters of a network without lazy modules as the order of parameter initializations, + which often depends on a stateful random number generator, is different. + Check :doc:`/notes/randomness` for more details. + + Lazy modules can be serialized with a state dict like other modules. For example: + + >>> lazy_mlp = LazyMLP() + >>> # The state dict shows the uninitialized parameters + >>> lazy_mlp.state_dict() + OrderedDict({'fc1.weight': , + 'fc1.bias': , + 'fc2.weight': , + 'fc2.bias': }) + + Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize + initialized LazyModules and they will remain initialized) + + + >>> full_mlp = LazyMLP() + >>> # Dry run to initialize another module + >>> full_mlp.forward(torch.ones(10, 1)) + >>> # Load an initialized state into a lazy module + >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) + >>> # The state dict now holds valid values + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', + tensor([[-0.3837], + [ 0.0907], + [ 0.6708], + [-0.5223], + [-0.9028], + [ 0.2851], + [-0.4537], + [ 0.6813], + [ 0.5766], + [-0.8678]])), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', + tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, + 0.2479, 0.1091]])), + ('fc2.bias', tensor([0.0019]))]) + + Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized + when the state is loaded. This prevents using initialized modules in different contexts. + """ + + # modules inheriting from this will change their __class__ to the specified + # one after they are fully initialized + cls_to_become: type[Any] | None = None + + def __init__(self: _LazyProtocol, *args, **kwargs): + # Mypy doesn't like this super call in a mixin + super().__init__(*args, **kwargs) # type: ignore[misc] + # pyrefly: ignore [read-only] + self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) + # pyrefly: ignore [read-only] + self._initialize_hook = self.register_forward_pre_hook( + self._infer_parameters, with_kwargs=True + ) + + def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): + # This should be ideally implemented as a hook, + # but we should override `detach` in the UninitializedParameter to return itself + # which is not clean + for name, param in self._parameters.items(): + if param is not None: + if not (is_lazy(param) or keep_vars): + param = param.detach() + destination[prefix + name] = param + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + if not (is_lazy(buf) or keep_vars): + buf = buf.detach() + destination[prefix + name] = buf + + def _lazy_load_hook( + self: _LazyProtocol, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """load_state_dict pre-hook function for lazy buffers and parameters. + + The purpose of this hook is to adjust the current state and/or + ``state_dict`` being loaded so that a module instance serialized in + both un/initialized state can be deserialized onto both un/initialized + module instance. + See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` + for the details of the hook specification. + """ + for name, param in itertools.chain( + self._parameters.items(), self._buffers.items() + ): + key = prefix + name + if key in state_dict and param is not None: + input_param = state_dict[key] + if is_lazy(param): + # The current parameter is not initialized but the one being loaded one is + # create a new parameter based on the uninitialized one + if not is_lazy(input_param): + with torch.no_grad(): + param.materialize(input_param.shape) + + def initialize_parameters(self: _LazyProtocol, *args, **kwargs): + r"""Initialize parameters according to the input batch properties. + + This adds an interface to isolate parameter initialization from the + forward pass when doing parameter shape inference. + """ + raise NotImplementedError( + f"initialize_parameters is not implemented for {self.__class__.__name__}" + ) + + def has_uninitialized_params(self: _LazyProtocol): + r"""Check if a module has parameters that are not initialized.""" + # This is to avoid the JIT to track this parameter and force + # custom modules __setstate__ to add it + params = self._parameters.values() + buffers = self._buffers.values() + for param in itertools.chain(params, buffers): + if is_lazy(param): + return True + return False + + # torchrec tests the code consistency with the following code + # fmt: off + def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): + r"""Infers the size and initializes the parameters according to the provided input batch. + + Given a module that contains parameters that were declared inferable + using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass + in the complete module using the provided input to initialize all the parameters + as needed. + The module is set into evaluation mode before running the forward pass in order + to avoid saving statistics or calculating gradients + """ + kwargs = kwargs if kwargs else {} + module.initialize_parameters(*args, **kwargs) + if module.has_uninitialized_params(): + raise RuntimeError(f'module {self._get_name()} has not been fully initialized') + module._initialize_hook.remove() + module._load_hook.remove() + delattr(module, '_initialize_hook') + delattr(module, '_load_hook') + if module.cls_to_become is not None: + module.__class__ = module.cls_to_become + # fmt: on + + def _replicate_for_data_parallel(self: _LazyProtocol): + raise RuntimeError( + "Modules with uninitialized parameters can't be used with `DataParallel`. " + "Run a dummy forward pass to correctly initialize the modules" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c58bdcefd0e0a9212d44891d6ade694e55c5f529 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/linear.py @@ -0,0 +1,337 @@ +# mypy: allow-untyped-defs +import math +from typing import Any + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter, UninitializedParameter + +from .lazy import LazyModuleMixin +from .module import Module + + +__all__ = [ + "Bilinear", + "Identity", + "LazyLinear", + "Linear", +] + + +class Identity(Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return input + + +class Linear(Module): + r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input: :math:`(*, H_\text{in})` where :math:`*` means any number of + dimensions including none and :math:`H_\text{in} = \text{in\_features}`. + - Output: :math:`(*, H_\text{out})` where all but the last dimension + are the same shape as the input and :math:`H_\text{out} = \text{out\_features}`. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + Examples:: + + >>> m = nn.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + +# This class exists solely to avoid triggering an obscure error when scripting +# an improperly quantized attention layer. See this issue for details: +# https://github.com/pytorch/pytorch/issues/58969 +# TODO: fail fast on quantization API usage error, then remove this class +# and replace uses of it with plain Linear +class NonDynamicallyQuantizableLinear(Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__( + in_features, out_features, bias=bias, device=device, dtype=dtype + ) + + +class Bilinear(Module): + r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`. + + Args: + in1_features: size of each first input sample, must be > 0 + in2_features: size of each second input sample, must be > 0 + out_features: size of each output sample, must be > 0 + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input1: :math:`(*, H_\text{in1})` where :math:`H_\text{in1}=\text{in1\_features}` and + :math:`*` means any number of additional dimensions including none. All but the last dimension + of the inputs should be the same. + - Input2: :math:`(*, H_\text{in2})` where :math:`H_\text{in2}=\text{in2\_features}`. + - Output: :math:`(*, H_\text{out})` where :math:`H_\text{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`. + The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + + Examples:: + + >>> m = nn.Bilinear(20, 30, 40) + >>> input1 = torch.randn(128, 20) + >>> input2 = torch.randn(128, 30) + >>> output = m(input1, input2) + >>> print(output.size()) + torch.Size([128, 40]) + """ + + __constants__ = ["in1_features", "in2_features", "out_features"] + in1_features: int + in2_features: int + out_features: int + weight: Tensor + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.weight = Parameter( + torch.empty((out_features, in1_features, in2_features), **factory_kwargs) + ) + + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + if self.in1_features <= 0: + raise ValueError( + f"in1_features must be > 0, but got (in1_features={self.in1_features})" + ) + bound = 1 / math.sqrt(self.weight.size(1)) + init.uniform_(self.weight, -bound, bound) + if self.bias is not None: + init.uniform_(self.bias, -bound, bound) + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.bilinear(input1, input2, self.weight, self.bias) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return ( + f"in1_features={self.in1_features}, in2_features={self.in2_features}, " + f"out_features={self.out_features}, bias={self.bias is not None}" + ) + + +class LazyLinear(LazyModuleMixin, Linear): + r"""A :class:`torch.nn.Linear` module where `in_features` is inferred. + + In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter` + class. They will be initialized after the first call to ``forward`` is done and the + module will become a regular :class:`torch.nn.Linear` module. The ``in_features`` argument + of the :class:`Linear` is inferred from the ``input.shape[-1]``. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + + """ + + cls_to_become = Linear # type: ignore[assignment] + # pyrefly: ignore [bad-override] + weight: UninitializedParameter + bias: UninitializedParameter # type: ignore[assignment] + + def __init__( + self, out_features: int, bias: bool = True, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + # pyrefly: ignore [bad-argument-type] + super().__init__(0, 0, False) + # pyrefly: ignore [bad-argument-type] + self.weight = UninitializedParameter(**factory_kwargs) + self.out_features = out_features + if bias: + # pyrefly: ignore [bad-argument-type] + self.bias = UninitializedParameter(**factory_kwargs) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in ``__init__``. + """ + # pyrefly: ignore [bad-argument-type] + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore[override] + """ + Infers ``in_features`` based on ``input`` and initializes parameters. + """ + # pyrefly: ignore [bad-argument-type] + if self.has_uninitialized_params(): + with torch.no_grad(): + self.in_features = input.shape[-1] + self.weight.materialize((self.out_features, self.in_features)) + if self.bias is not None: + self.bias.materialize((self.out_features,)) + self.reset_parameters() + if self.in_features == 0: + assert input.shape[-1] == self.weight.shape[-1], ( + f"The in_features inferred from input: {input.shape[-1]} " + f"is not equal to in_features from self.weight: " + f"{self.weight.shape[-1]}" + ) + self.in_features = input.shape[-1] + + +# TODO: PartialLinear - maybe in sparse? diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/loss.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..00ada62febded14af25c6a32ec8c1e5998349d74 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/loss.py @@ -0,0 +1,2083 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing_extensions import deprecated + +from torch import Tensor +from torch.nn import _reduction as _Reduction, functional as F + +from .distance import PairwiseDistance +from .module import Module + + +__all__ = [ + "L1Loss", + "NLLLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "GaussianNLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "HingeEmbeddingLoss", + "MultiLabelMarginLoss", + "SmoothL1Loss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", + "CosineEmbeddingLoss", + "MarginRankingLoss", + "MultiMarginLoss", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "CTCLoss", +] + + +class _Loss(Module): + reduction: str + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super().__init__() + if size_average is not None or reduce is not None: + self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) + else: + self.reduction = reduction + + +class _WeightedLoss(_Loss): + def __init__( + self, + weight: Tensor | None = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.weight: Tensor | None + + +class L1Loss(_Loss): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`N` elements each. + + The sum operation still operates over all the elements, and divides by :math:`N`. + + The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then + :math:`(*)`, same shape as the input. + + Examples: + + >>> loss = nn.L1Loss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.l1_loss(input, target, reduction=self.reduction) + + +class NLLLoss(_WeightedLoss): + r"""The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. + + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. + + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \\ + l_n = - w_{y_n} x_{n,y_n}, \\ + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``None`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When + :attr:`size_average` is ``True``, the loss is averaged over + non-ignored targets. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``None`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. + + Examples: + + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> loss_fn = nn.NLLLoss() + >>> # input to NLLLoss is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target must have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> loss = loss_fn(log_softmax(input), target) + >>> loss.backward() + >>> + >>> + >>> # 2D loss example (used, for example, with image inputs) + >>> N, C = 5, 4 + >>> loss_fn = nn.NLLLoss() + >>> data = torch.randn(N, 16, 10, 10) + >>> conv = nn.Conv2d(16, C, (3, 3)) + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> # output of conv forward is of shape [N, C, 8, 8] + >>> output = log_softmax(conv(data)) + >>> # each element in target must have 0 <= value < C + >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + >>> # input to NLLLoss is of size N x C x height (8) x width (8) + >>> loss = loss_fn(output, target) + >>> loss.backward() + """ + + __constants__ = ["ignore_index", "reduction"] + ignore_index: int + + def __init__( + self, + weight: Tensor | None = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.nll_loss( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + ) + + +@deprecated( + "`NLLLoss2d` has been deprecated. " + "Please use `NLLLoss` instead as a drop-in replacement and see " + "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.", + category=FutureWarning, +) +class NLLLoss2d(NLLLoss): + def __init__( + self, + weight: Tensor | None = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, ignore_index, reduce, reduction) + + +class PoissonNLLLoss(_Loss): + r"""Negative log likelihood loss with Poisson distribution of target. + + The loss can be described as: + + .. math:: + \text{target} \sim \mathrm{Poisson}(\text{input}) + + \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + + \log(\text{target!}) + + The last term can be omitted or approximated with Stirling formula. The + approximation is used for target values more than 1. For targets less or + equal to 1 zeros are added to the loss. + + Args: + log_input (bool, optional): if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is + :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`. + full (bool, optional): whether to compute full loss, i. e. to add the + Stirling approximation term + + .. math:: + \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}). + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input = False`. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples: + + >>> loss = nn.PoissonNLLLoss() + >>> log_input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> output = loss(log_input, target) + >>> output.backward() + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, + the same shape as the input. + """ + + __constants__ = ["log_input", "full", "eps", "reduction"] + log_input: bool + full: bool + eps: float + + def __init__( + self, + log_input: bool = True, + full: bool = False, + size_average=None, + eps: float = 1e-8, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.log_input = log_input + self.full = full + self.eps = eps + + def forward(self, log_input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.poisson_nll_loss( + log_input, + target, + log_input=self.log_input, + full=self.full, + eps=self.eps, + reduction=self.reduction, + ) + + +class GaussianNLLLoss(_Loss): + r"""Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting), or a scalar value + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Examples: + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + + __constants__ = ["full", "eps", "reduction"] + full: bool + eps: float + + def __init__( + self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean" + ) -> None: + super().__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward(self, input: Tensor, target: Tensor, var: Tensor | float) -> Tensor: + """ + Runs the forward pass. + """ + return F.gaussian_nll_loss( + input, target, var, full=self.full, eps=self.eps, reduction=self.reduction + ) + + +class KLDivLoss(_Loss): + r"""The Kullback-Leibler divergence loss. + + For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`, + where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the + :attr:`target`, we define the **pointwise KL-divergence** as + + .. math:: + + L(y_{\text{pred}},\ y_{\text{true}}) + = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} + = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}}) + + To avoid underflow issues when computing this quantity, this loss expects the argument + :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the + log-space if :attr:`log_target`\ `= True`. + + To summarise, this function is roughly equivalent to computing + + .. code-block:: python + + if not log_target: # default + loss_pointwise = target * (target.log() - input) + else: + loss_pointwise = target.exp() * (target - input) + + and then reducing this result depending on the argument :attr:`reduction` as + + .. code-block:: python + + if reduction == "mean": # default + loss = loss_pointwise.mean() + elif reduction == "batchmean": # mathematically correct + loss = loss_pointwise.sum() / input.size(0) + elif reduction == "sum": + loss = loss_pointwise.sum() + else: # reduction == "none" + loss = loss_pointwise + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`input`, to be the output of the model (e.g. the neural network) + and the second, :attr:`target`, to be the observations in the dataset. + This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where + :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model. + + .. warning:: + :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use + :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to `False`, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is `False`. Default: `True` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: `True` + reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"` + log_target (bool, optional): Specifies whether `target` is the log space. Default: `False` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`, + same shape as the input. + + Examples: + >>> kl_loss = nn.KLDivLoss(reduction="batchmean") + >>> # input should be a distribution in the log space + >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1) + >>> # Sample a batch of distributions. Usually this would come from the dataset + >>> target = F.softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, target) + >>> + >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True) + >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, log_target) + """ + + __constants__ = ["reduction"] + + def __init__( + self, + size_average=None, + reduce=None, + reduction: str = "mean", + log_target: bool = False, + ) -> None: + super().__init__(size_average, reduce, reduction) + self.log_target = log_target + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.kl_div( + input, target, reduction=self.reduction, log_target=self.log_target + ) + + +class MSELoss(_Loss): + r"""Creates a criterion that measures the mean squared error (squared L2 norm) between + each element in the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left( x_n - y_n \right)^2, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`N` elements each. + + The mean operation still operates over all the elements, and divides by :math:`N`. + + The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + + Examples: + + >>> loss = nn.MSELoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.mse_loss(input, target, reduction=self.reduction) + + +class BCELoss(_WeightedLoss): + r"""Creates a criterion that measures the Binary Cross Entropy between the target and + the input probabilities: + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets :math:`y` should be numbers + between 0 and 1. + + Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be + mathematically undefined in the above loss equation. PyTorch chooses to set + :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`. + However, an infinite term in the loss equation is not desirable for several reasons. + + For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be + multiplying 0 with infinity. Secondly, if we have an infinite loss value, then + we would also have an infinite term in our gradient, since + :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`. + This would make BCELoss's backward method nonlinear with respect to :math:`x_n`, + and using it for things like linear regression would not be straight-forward. + + Our solution is that BCELoss clamps its log function outputs to be greater than + or equal to -100. This way, we can always have a finite loss value and a linear + backward method. + + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> output = loss(m(input), target) + >>> output.backward() + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.binary_cross_entropy( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class BCEWithLogitsLoss(_Loss): + r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single + class. This version is more numerically stable than using a plain `Sigmoid` + followed by a `BCELoss` as, by combining the operations into one layer, + we take advantage of the log-sum-exp trick for numerical stability. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets `t[i]` should be numbers + between 0 and 1. + + It's possible to trade off recall and precision by adding weights to positive examples. + In the case of multi-label classification the loss can be described as: + + .. math:: + \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad + l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right], + + where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification, + :math:`c = 1` for single-label binary classification), + :math:`n` is the number of the sample in the batch and + :math:`p_c` is the weight of the positive answer for the class :math:`c`. + + :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision. + + For example, if a dataset contains 100 positive and 300 negative examples of a single class, + then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`. + The loss would act as if the dataset contains :math:`3\times 100=300` positive examples. + + Examples: + + >>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 + >>> output = torch.full([10, 64], 1.5) # A prediction (logit) + >>> pos_weight = torch.ones([64]) # All weights are equal to 1 + >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + >>> criterion(output, target) # -log(sigmoid(1.5)) + tensor(0.20...) + + In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes + in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the + loss function based on the imbalance between negative and positive samples for the respective class. + This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss + calculation accurately accounts for the distribution in each class. + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples: + + >>> loss = nn.BCEWithLogitsLoss() + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> output = loss(input, target) + >>> output.backward() + """ + + def __init__( + self, + weight: Tensor | None = None, + size_average=None, + reduce=None, + reduction: str = "mean", + pos_weight: Tensor | None = None, + ) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.register_buffer("pos_weight", pos_weight) + self.weight: Tensor | None + self.pos_weight: Tensor | None + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.binary_cross_entropy_with_logits( + input, + target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction, + ) + + +class HingeEmbeddingLoss(_Loss): + r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` + (containing 1 or -1). + This is usually used for measuring whether two inputs are similar or + dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically + used for learning nonlinear embeddings or semi-supervised learning. + + The loss function for :math:`n`-th sample in the mini-batch is + + .. math:: + l_n = \begin{cases} + x_n, & \text{if}\; y_n = 1,\\ + \max \{0, margin - x_n\}, & \text{if}\; y_n = -1, + \end{cases} + + and the total loss functions is + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + where :math:`L = \{l_1,\dots,l_N\}^\top`. + + Args: + margin (float, optional): Has a default value of `1`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation + operates over all the elements. + - Target: :math:`(*)`, same shape as the input + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 1.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.hinge_embedding_loss( + input, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiLabelMarginLoss(_Loss): + r"""Creates a criterion that optimizes a multi-class multi-classification + hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 2D `Tensor` of target class indices). + For each sample in the mini-batch: + + .. math:: + \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \ + :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \ + :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \ + and :math:`i \neq y[j]` for all :math:`i` and :math:`j`. + + :math:`y` and :math:`x` must have the same size. + + The criterion only considers a contiguous block of non-negative targets that + starts at the front. + + This allows for different samples to have variable amounts of target classes. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C` + is the number of classes. + - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + + Examples: + + >>> loss = nn.MultiLabelMarginLoss() + >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]]) + >>> # for target y, only consider labels 3 and 0, not after label -1 + >>> y = torch.LongTensor([[3, 0, -1, 1]]) + >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.85...) + + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.multilabel_margin_loss(input, target, reduction=self.reduction) + + +class SmoothL1Loss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases + prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick). + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\ + |x_n - y_n| - 0.5 * beta, & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta` + portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`. + The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`. + + .. note:: + Smooth L1 loss is closely related to :class:`HuberLoss`, being + equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is + also known as delta for Huber). This leads to the following differences: + + * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss` + converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss. + * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while + :class:`HuberLoss` converges to :class:`MSELoss`. + * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. + For :class:`HuberLoss`, the slope of the L1 segment is beta. + + .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083 + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. + The value must be non-negative. Default: 1.0 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + + __constants__ = ["reduction"] + + def __init__( + self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0 + ) -> None: + super().__init__(size_average, reduce, reduction) + self.beta = beta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) + + +class HuberLoss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the + delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`, + while the L2 region provides smoothness over :class:`L1Loss` near 0. See + `Huber loss `_ for more information. + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\ + delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`. + In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta + in Smooth L1). + See :class:`SmoothL1Loss` for additional discussion on the differences in behavior + between the two losses. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + The value must be positive. Default: 1.0 + + Shape: + - Input: :math:`(*)` where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + + __constants__ = ["reduction", "delta"] + + def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: + super().__init__(reduction=reduction) + self.delta = delta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) + + +class SoftMarginLoss(_Loss): + r"""Creates a criterion that optimizes a two-class classification + logistic loss between input tensor :math:`x` and target tensor :math:`y` + (containing 1 or -1). + + .. math:: + \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()} + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.soft_margin_loss(input, target, reduction=self.reduction) + + +class CrossEntropyLoss(_WeightedLoss): + r"""This criterion computes the cross entropy loss between input logits + and target. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. + + The `input` is expected to contain the unnormalized logits for each class (which do `not` need + to be positive or sum to 1, in general). + `input` has to be a Tensor of size :math:`(C)` for unbatched input, + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the + `K`-dimensional case. The last being useful for higher dimension inputs, such + as computing cross entropy loss per-pixel for 2D images. + + The `target` that this criterion expects should contain either: + + - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Note that this case is equivalent to applying :class:`~torch.nn.LogSoftmax` + on an input, followed by :class:`~torch.nn.NLLLoss`. + + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The + target data type is required to be long when using class indices. If containing class probabilities, the + target must be the same shape input, and each value should be between :math:`[0, 1]`. This means the target + data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce + probability constraints on the class probabilities and that it is the user's responsibility to ensure + ``target`` contains valid probability distributions (see below examples section for more details). + - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. + + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples: + + >>> # Example of target with class indices + >>> loss = nn.CrossEntropyLoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.empty(3, dtype=torch.long).random_(5) + >>> output = loss(input, target) + >>> output.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> output = loss(input, target) + >>> output.backward() + + .. note:: + When ``target`` contains class probabilities, it should consist of soft labels—that is, + each ``target`` entry should represent a probability distribution over the possible classes for a given data sample, + with individual probabilities between ``[0,1]`` and the total distribution summing to 1. + This is why the :func:`softmax()` function is applied to the ``target`` in the class probabilities example above. + + PyTorch does not validate whether the values provided in ``target`` lie in the range ``[0,1]`` + or whether the distribution of each data sample sums to ``1``. + No warning will be raised and it is the user's responsibility + to ensure that ``target`` contains valid probability distributions. + Providing arbitrary values may yield misleading loss values and unstable gradients during training. + + Examples: + >>> # xdoctest: +SKIP + >>> # Example of target with incorrectly specified class probabilities + >>> loss = nn.CrossEntropyLoss() + >>> torch.manual_seed(283) + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> # Provided target class probabilities are not in range [0,1] + >>> target + tensor([[ 0.7105, 0.4446, 2.0297, 0.2671, -0.6075], + [-1.0496, -0.2753, -0.3586, 0.9270, 1.0027], + [ 0.7551, 0.1003, 1.3468, -0.3581, -0.9569]]) + >>> # Provided target class probabilities do not sum to 1 + >>> target.sum(axis=1) + tensor([2.8444, 0.2462, 0.8873]) + >>> # No error message and possible misleading loss value + >>> loss(input, target).item() + 4.6379876136779785 + >>> + >>> # Example of target with correctly specified class probabilities + >>> # Use .softmax() to ensure true probability distribution + >>> target_new = target.softmax(dim=1) + >>> # New target class probabilities all in range [0,1] + >>> target_new + tensor([[0.1559, 0.1195, 0.5830, 0.1000, 0.0417], + [0.0496, 0.1075, 0.0990, 0.3579, 0.3860], + [0.2607, 0.1355, 0.4711, 0.0856, 0.0471]]) + >>> # New target class probabilities sum to 1 + >>> target_new.sum(axis=1) + tensor([1.0000, 1.0000, 1.0000]) + >>> loss(input, target_new).item() + 2.55349063873291 + """ + + __constants__ = ["ignore_index", "reduction", "label_smoothing"] + ignore_index: int + label_smoothing: float + + def __init__( + self, + weight: Tensor | None = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ) + + +class MultiLabelSoftMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-label one-versus-all + loss based on max-entropy, between input :math:`x` and target :math:`y` of size + :math:`(N, C)`. + For each sample in the minibatch: + + .. math:: + loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1}) + + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right) + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`, + :math:`y[i] \in \left\{0, \; 1\right\}`. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes. + - Target: :math:`(N, C)`, label targets must have the same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + """ + + __constants__ = ["reduction"] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.multilabel_soft_margin_loss( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class CosineEmbeddingLoss(_Loss): + r"""Creates a criterion that measures the loss given input tensors + :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. + Use (:math:`y=1`) to maximize the cosine similarity of two inputs, and (:math:`y=-1`) otherwise. + This is typically used for learning nonlinear + embeddings or semi-supervised learning. + + The loss function for each sample is: + + .. math:: + \text{loss}(x, y) = + \begin{cases} + 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ + \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 + \end{cases} + + Args: + margin (float, optional): Should be a number from :math:`-1` to :math:`1`, + :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the + default value is :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension. + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1. + - Target: :math:`(N)` or :math:`()`. + - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar. + + Examples: + + >>> loss = nn.CosineEmbeddingLoss() + >>> input1 = torch.randn(3, 5, requires_grad=True) + >>> input2 = torch.randn(3, 5, requires_grad=True) + >>> target = torch.ones(3) + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.cosine_embedding_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MarginRankingLoss(_Loss): + r"""Creates a criterion that measures the loss given + inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`, + and a label 1D mini-batch or 0D `Tensor` :math:`y` (containing 1 or -1). + + If :math:`y = 1` then it assumed the first input should be ranked higher + (have a larger value) than the second input, and vice-versa for :math:`y = -1`. + + The loss function for each pair of samples in the mini-batch is: + + .. math:: + \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin}) + + Args: + margin (float, optional): Has a default value of :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N)` or :math:`()` where `N` is the batch size. + - Input2: :math:`(N)` or :math:`()`, same shape as the Input1. + - Target: :math:`(N)` or :math:`()`, same shape as the inputs. + - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`. + + Examples: + + >>> loss = nn.MarginRankingLoss() + >>> input1 = torch.randn(3, requires_grad=True) + >>> input2 = torch.randn(3, requires_grad=True) + >>> target = torch.randn(3).sign() + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.margin_ranking_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-class classification hinge + loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and + output :math:`y` (which is a 1D tensor of target class indices, + :math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + Args: + p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2` + are the only supported values. + margin (float, optional): Has a default value of :math:`1`. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes. + - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target. + + Examples: + + >>> loss = nn.MultiMarginLoss() + >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]]) + >>> y = torch.tensor([3]) + >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.32...) + """ + + __constants__ = ["p", "margin", "reduction"] + margin: float + p: int + + def __init__( + self, + p: int = 1, + margin: float = 1.0, + weight: Tensor | None = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(weight, size_average, reduce, reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None and weight.dim() != 1: + raise ValueError( + f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead" + ) + self.p = p + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.multi_margin_loss( + input, + target, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + ) + + +class TripletMarginLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, D)`. + + The distance swap is described in detail in the paper `Learning shallow + convolutional feature descriptors with triplet losses`_ by + V. Balntas, E. Riba et al. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is + added for numerical stability. + + See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the + triplet margin loss for input tensors using a custom distance function. + + Args: + margin (float, optional): Default: :math:`1`. + p (int, optional): The norm degree for pairwise distance. Default: :math:`2`. + eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`. + swap (bool, optional): The distance swap is described in detail in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. Default: ``False``. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and + input shape is :math:`(N, D)`; a scalar otherwise. + + Examples: + + >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + >>> anchor = torch.randn(100, 128, requires_grad=True) + >>> positive = torch.randn(100, 128, requires_grad=True) + >>> negative = torch.randn(100, 128, requires_grad=True) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + .. _Learning shallow convolutional feature descriptors with triplet losses: + https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html + """ + + __constants__ = ["margin", "p", "eps", "swap", "reduction"] + margin: float + p: float + eps: float + swap: bool + + def __init__( + self, + margin: float = 1.0, + p: float = 2.0, + eps: float = 1e-6, + swap: bool = False, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super().__init__(size_average, reduce, reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.margin = margin + self.p = p + self.eps = eps + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.triplet_margin_loss( + anchor, + positive, + negative, + margin=self.margin, + p=self.p, + eps=self.eps, + swap=self.swap, + reduction=self.reduction, + ) + + +class TripletMarginWithDistanceLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given input + tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, + positive, and negative examples, respectively), and a nonnegative, + real-valued function ("distance function") used to compute the relationship + between the anchor and positive example ("positive distance") and the + anchor and negative example ("negative distance"). + + The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``) + can be described as: + + .. math:: + \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad + l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function + quantifying the closeness of two tensors, referred to as the :attr:`distance_function`; + and :math:`margin` is a nonnegative margin representing the minimum difference + between the positive and negative distances that is required for the loss to + be 0. The input tensors have :math:`N` elements each and can be of any shape + that the distance function can handle. + + If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet + loss for input tensors using the :math:`l_p` distance as the distance function. + + Args: + distance_function (Callable, optional): A nonnegative, real-valued function that + quantifies the closeness of two tensors. If not specified, + `nn.PairwiseDistance` will be used. Default: ``None`` + margin (float, optional): A nonnegative margin representing the minimum difference + between the positive and negative distances required for the loss to be 0. Larger + margins penalize cases where the negative examples are not distant enough from the + anchors, relative to the positives. Default: :math:`1`. + swap (bool, optional): Whether to use the distance swap described in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. If True, and if the positive example is closer to the + negative example than the anchor is, swaps the positive example and the anchor in + the loss computation. Default: ``False``. + reduction (str, optional): Specifies the (optional) reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + + + Shape: + - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions + as supported by the distance function. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar + otherwise. + + Examples: + + >>> # Initialize embeddings + >>> embedding = nn.Embedding(1000, 128) + >>> anchor_ids = torch.randint(0, 1000, (1,)) + >>> positive_ids = torch.randint(0, 1000, (1,)) + >>> negative_ids = torch.randint(0, 1000, (1,)) + >>> anchor = embedding(anchor_ids) + >>> positive = embedding(positive_ids) + >>> negative = embedding(negative_ids) + >>> + >>> # Built-in Distance Function + >>> triplet_loss = \ + >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance()) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function + >>> def l_infinity(x1, x2): + >>> return torch.max(torch.abs(x1 - x2), dim=1).values + >>> + >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time") + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function (Lambda) + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss( + >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + Reference: + V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: + https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html + """ + + __constants__ = ["margin", "swap", "reduction"] + margin: float + swap: bool + + def __init__( + self, + *, + distance_function: Callable[[Tensor, Tensor], Tensor] | None = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", + ) -> None: + super().__init__(size_average=None, reduce=None, reduction=reduction) + if margin <= 0: + raise ValueError( + f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" + ) + self.distance_function: Callable[[Tensor, Tensor], Tensor] | None = ( + distance_function if distance_function is not None else PairwiseDistance() + ) + self.margin = margin + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.triplet_margin_with_distance_loss( + anchor, + positive, + negative, + distance_function=self.distance_function, + margin=self.margin, + swap=self.swap, + reduction=self.reduction, + ) + + +class CTCLoss(_Loss): + r"""The Connectionist Temporal Classification loss. + + Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the + probability of possible alignments of input to target, producing a loss value which is differentiable + with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which + limits the length of the target sequence such that it must be :math:`\leq` the input length. + + Args: + blank (int, optional): blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output losses will be summed. + Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Shape: + - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`, + where :math:`T = \text{input length}`, + :math:`N = \text{batch size}`, and + :math:`C = \text{number of classes (including blank)}`. + The logarithmized probabilities of the outputs (e.g. obtained with + :func:`torch.nn.functional.log_softmax`). + - Targets: Tensor of size :math:`(N, S)` or + :math:`(\operatorname{sum}(\text{target\_lengths}))`, + where :math:`N = \text{batch size}` and + :math:`S = \text{max target length, if shape is } (N, S)`. + It represents the target sequences. Each element in the target + sequence is a class index. And the target index cannot be blank (default=0). + In the :math:`(N, S)` form, targets are padded to the + length of the longest sequence, and stacked. + In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form, + the targets are assumed to be un-padded and + concatenated within 1 dimension. + - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents the lengths of the + inputs (must each be :math:`\leq T`). And the lengths are specified + for each sequence to achieve masking under the assumption that sequences + are padded to equal lengths. + - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represents lengths of the targets. + Lengths are specified for each sequence to achieve masking under the + assumption that sequences are padded to equal lengths. If target shape is + :math:`(N,S)`, target_lengths are effectively the stop index + :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for + each target in a batch. Lengths must each be :math:`\leq S` + If the targets are given as a 1d tensor that is the concatenation of individual + targets, the target_lengths must add up to the total length of the tensor. + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or + :math:`()` if input is unbatched, where :math:`N = \text{batch size}`. + + Examples: + + >>> # Target are to be padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> S = 30 # Target sequence length of longest target in batch (padding length) + >>> S_min = 10 # Minimum target length, for demonstration purposes + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) + >>> + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> target_lengths = torch.randint( + ... low=S_min, + ... high=S, + ... size=(N,), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(sum(target_lengths),), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded and unbatched (effectively N=1) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> + >>> # Initialize random batch of input vectors, for *size = (T,C) + >>> # xdoctest: +SKIP("FIXME: error in doctest") + >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() + >>> input_lengths = torch.tensor(T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(target_lengths,), + ... dtype=torch.long, + ... ) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + Note: + In order to use CuDNN, the following must be satisfied: the :attr:`targets` must be + in concatenated format, all :attr:`input_lengths` must be `T`. :math:`blank=0`, + :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of + dtype :attr:`torch.int32`, and the :attr:`log_probs` itself must be of + dtype :attr:`torch.float32`. + + The regular implementation uses the (more common in PyTorch) `torch.long` dtype. + + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + """ + + __constants__ = ["blank", "reduction"] + blank: int + zero_infinity: bool + + def __init__( + self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False + ) -> None: + super().__init__(reduction=reduction) + self.blank = blank + self.zero_infinity = zero_infinity + + def forward( + self, + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + ) -> Tensor: + """Runs the forward pass.""" + return F.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + self.blank, + self.reduction, + self.zero_infinity, + ) + + +# TODO: L1HingeEmbeddingCriterion +# TODO: MSECriterion weight +# TODO: ClassSimplexCriterion diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/module.py new file mode 100644 index 0000000000000000000000000000000000000000..e9123f76b75c31d71c2e863c2cdb3c87f862291f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/module.py @@ -0,0 +1,3046 @@ +# mypy: allow-untyped-defs + +import functools +import inspect +import itertools +import warnings +import weakref +from collections import namedtuple, OrderedDict +from collections.abc import Callable, Iterator, Mapping +from typing import Any, Optional, overload, TypeVar, Union +from typing_extensions import Self + +import torch +from torch import device, dtype, Tensor +from torch._prims_common import DeviceLikeType +from torch.nn.parameter import Buffer, Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils.hooks import BackwardHook, RemovableHandle + + +__all__ = [ + "register_module_forward_pre_hook", + "register_module_forward_hook", + "register_module_full_backward_pre_hook", + "register_module_backward_hook", + "register_module_full_backward_hook", + "register_module_buffer_registration_hook", + "register_module_module_registration_hook", + "register_module_parameter_registration_hook", + "Module", +] + +_grad_t = Union[tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar("T", bound="Module") + + +class _IncompatibleKeys( + # pyrefly: ignore [invalid-inheritance] + namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), +): + __slots__ = () + + def __repr__(self) -> str: + # pyrefly: ignore [missing-attribute] + if not self.missing_keys and not self.unexpected_keys: + return "" + return super().__repr__() + + __str__ = __repr__ + + +def _addindent(s_, numSpaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +r"""This tracks hooks common to all modules that are executed immediately before +.registering the buffer/module/parameter""" +_global_buffer_registration_hooks: dict[int, Callable] = OrderedDict() +_global_module_registration_hooks: dict[int, Callable] = OrderedDict() +_global_parameter_registration_hooks: dict[int, Callable] = OrderedDict() + + +class _WrappedHook: + def __init__(self, hook: Callable, module: Optional["Module"] = None) -> None: + self.hook: Callable = hook + functools.update_wrapper(self, hook) + + self.with_module: bool = False + + if module is not None: + self.module: weakref.ReferenceType[Module] = weakref.ref(module) + self.with_module = True + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.with_module: + module = self.module() + if module is None: + raise RuntimeError("You are trying to call the hook of a dead Module!") + return self.hook(module, *args, **kwargs) + return self.hook(*args, **kwargs) + + def __getstate__(self) -> dict: + result = {"hook": self.hook, "with_module": self.with_module} + if self.with_module: + # pyrefly: ignore [unsupported-operation] + result["module"] = self.module() + + return result + + def __setstate__(self, state: dict): + self.hook = state["hook"] + self.with_module = state["with_module"] + + if self.with_module: + if state["module"] is None: + raise RuntimeError( + "You are trying to revive the hook of a dead Module!" + ) + self.module = weakref.ref(state["module"]) + + +r"""This tracks hooks common to all modules that are executed before/after +calling forward and backward. This is global state used for debugging/profiling +purposes""" +_global_backward_pre_hooks: dict[int, Callable] = OrderedDict() +_global_backward_hooks: dict[int, Callable] = OrderedDict() +_global_is_full_backward_hook: bool | None = None +_global_forward_pre_hooks: dict[int, Callable] = OrderedDict() +_global_forward_hooks: dict[int, Callable] = OrderedDict() +_global_forward_hooks_always_called: dict[int, bool] = OrderedDict() +_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict() + + +def _has_any_global_hook(): + return ( + _global_backward_pre_hooks + or _global_backward_hooks + or _global_forward_pre_hooks + or _global_forward_hooks + or _global_forward_hooks_always_called + or _global_forward_hooks_with_kwargs + ) + + +_EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +def register_module_buffer_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a buffer registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_buffer` is invoked. + It should have the following signature:: + + hook(module, name, buffer) -> None or new buffer + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_buffer_registration_hooks) + _global_buffer_registration_hooks[handle.id] = hook + return handle + + +def register_module_module_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a module registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_module` is invoked. + It should have the following signature:: + + hook(module, name, submodule) -> None or new submodule + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_module_registration_hooks) + _global_module_registration_hooks[handle.id] = hook + return handle + + +def register_module_parameter_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: + r"""Register a parameter registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_parameter` is invoked. + It should have the following signature:: + + hook(module, name, param) -> None or new parameter + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_parameter_registration_hooks) + _global_parameter_registration_hooks[handle.id] = hook + return handle + + +def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a forward pre-hook common to all modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time before :func:`forward` is invoked. + It should have the following signature:: + + hook(module, input) -> None or modified input + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + The hook can modify the input. User can either return a tuple or a + single modified value in the hook. We will wrap the value into a tuple + if a single value is returned(unless that value is already a tuple). + + This hook has precedence over the specific module hooks registered with + ``register_forward_pre_hook``. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(_global_forward_pre_hooks) + _global_forward_pre_hooks[handle.id] = hook + return handle + + +def register_module_forward_hook( + hook: Callable[..., None], + *, + with_kwargs: bool = False, + always_call: bool = False, +) -> RemovableHandle: + r"""Register a global forward hook for all the modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time after :func:`forward` has computed an output. + It should have the following signature:: + + hook(module, input, output) -> None or modified output + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + You can optionally modify the output of the module by returning a new value + that will replace the output from the :func:`forward` function. + + Parameters: + hook (Callable): The user defined hook to be registered. + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + This hook will be executed before specific module hooks registered with + ``register_forward_hook``. + """ + handle = RemovableHandle( + _global_forward_hooks, extra_dict=_global_forward_hooks_always_called + ) + _global_forward_hooks[handle.id] = hook + if with_kwargs: + _global_forward_hooks_with_kwargs[handle.id] = True + if always_call: + _global_forward_hooks_always_called[handle.id] = True + return handle + + +def register_module_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], None | _grad_t], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + This function is deprecated in favor of + :func:`torch.nn.modules.module.register_module_full_backward_hook` + and the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is True: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = False + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_pre_hook( + hook: Callable[["Module", _grad_t], None | _grad_t], +) -> RemovableHandle: + r"""Register a backward pre-hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_pre_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = RemovableHandle(_global_backward_pre_hooks) + _global_backward_pre_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_hook( + hook: Callable[["Module", _grad_t, _grad_t], None | _grad_t], +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is False: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) + + _global_is_full_backward_hook = True + + handle = RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +# Trick mypy into not applying contravariance rules to inputs by defining +# forward as a value, rather than a function. See also +# https://github.com/python/mypy/issues/8795 +def _forward_unimplemented(self, *input: Any) -> None: + r"""Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError( + f'Module [{type(self).__name__}] is missing the required "forward" function' + ) + + +class Module: + r"""Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing them to be nested in + a tree structure. You can assign the submodules as regular attributes:: + + import torch.nn as nn + import torch.nn.functional as F + + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + Submodules assigned in this way will be registered, and will also have their + parameters converted when you call :meth:`to`, etc. + + .. note:: + As per the example above, an ``__init__()`` call to the parent class + must be made before assignment on the child. + + :ivar training: Boolean represents whether this module is in training or + evaluation mode. + :vartype training: bool + """ + + dump_patches: bool = False + + _version: int = 1 + r"""This allows better BC support for :meth:`load_state_dict`. In + :meth:`state_dict`, the version number will be saved as in the attribute + `_metadata` of the returned state dict, and thus pickled. `_metadata` is a + dictionary with keys that follow the naming convention of state dict. See + ``_load_from_state_dict`` on how to use this information in loading. + + If new parameters/buffers are added/removed from a module, this number shall + be bumped, and the module's `_load_from_state_dict` method can compare the + version number and do appropriate changes if the state dict is from before + the change.""" + + training: bool + _parameters: dict[str, Parameter | None] + _buffers: dict[str, Tensor | None] + _non_persistent_buffers_set: set[str] + _backward_pre_hooks: dict[int, Callable] + _backward_hooks: dict[int, Callable] + _is_full_backward_hook: bool | None + _forward_hooks: dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_hooks_with_kwargs: dict[int, bool] + # forward hooks that should always be called even if an exception is raised + _forward_hooks_always_called: dict[int, bool] + _forward_pre_hooks: dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_pre_hooks_with_kwargs: dict[int, bool] + _state_dict_hooks: dict[int, Callable] + _load_state_dict_pre_hooks: dict[int, Callable] + _state_dict_pre_hooks: dict[int, Callable] + _load_state_dict_post_hooks: dict[int, Callable] + _modules: dict[str, Optional["Module"]] + call_super_init: bool = False + _compiled_call_impl: Callable | None = None + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" + torch._C._log_api_usage_once("python.nn_module") + + # Backward compatibility: no args used to be allowed when call_super_init=False + if self.call_super_init is False and bool(kwargs): + raise TypeError( + f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'" + "" + ) + + if self.call_super_init is False and bool(args): + raise TypeError( + f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" + " given" + ) + + """ + Calls super().__setattr__('a', a) instead of the typical self.a = a + to avoid Module.__setattr__ overhead. Module's __setattr__ has special + handling for parameters, submodules, and buffers but simply calls into + super().__setattr__ for all other attributes. + """ + super().__setattr__("training", True) + super().__setattr__("_parameters", {}) + super().__setattr__("_buffers", {}) + super().__setattr__("_non_persistent_buffers_set", set()) + super().__setattr__("_backward_pre_hooks", OrderedDict()) + super().__setattr__("_backward_hooks", OrderedDict()) + super().__setattr__("_is_full_backward_hook", None) + super().__setattr__("_forward_hooks", OrderedDict()) + super().__setattr__("_forward_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_forward_hooks_always_called", OrderedDict()) + super().__setattr__("_forward_pre_hooks", OrderedDict()) + super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_state_dict_hooks", OrderedDict()) + super().__setattr__("_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_post_hooks", OrderedDict()) + super().__setattr__("_modules", {}) + + if self.call_super_init: + super().__init__(*args, **kwargs) + + forward: Callable[..., Any] = _forward_unimplemented + + def register_buffer( + self, name: str, tensor: Tensor | None, persistent: bool = True + ) -> None: + r"""Add a buffer to the module. + + This is typically used to register a buffer that should not be + considered a model parameter. For example, BatchNorm's ``running_mean`` + is not a parameter, but is part of the module's state. Buffers, by + default, are persistent and will be saved alongside parameters. This + behavior can be changed by setting :attr:`persistent` to ``False``. The + only difference between a persistent buffer and a non-persistent buffer + is that the latter will not be a part of this module's + :attr:`state_dict`. + + Buffers can be accessed as attributes using given names. + + Args: + name (str): name of the buffer. The buffer can be accessed + from this module using the given name + tensor (Tensor or None): buffer to be registered. If ``None``, then operations + that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, + the buffer is **not** included in the module's :attr:`state_dict`. + persistent (bool): whether the buffer is part of this module's + :attr:`state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> self.register_buffer('running_mean', torch.zeros(num_features)) + + """ + if persistent is False and isinstance(self, torch.jit.ScriptModule): + raise RuntimeError("ScriptModule does not support non-persistent buffers") + + if "_buffers" not in self.__dict__: + raise AttributeError("cannot assign buffer before Module.__init__() call") + elif not isinstance(name, str): + raise TypeError( + f"buffer name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('buffer name can\'t contain "."') + elif name == "": + raise KeyError('buffer name can\'t be empty string ""') + elif hasattr(self, name) and name not in self._buffers: + raise KeyError(f"attribute '{name}' already exists") + elif tensor is not None and not ( + isinstance(tensor, torch.Tensor) or hasattr(tensor, "__torch_function__") + ): + raise TypeError( + f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " + "(torch Tensor or None required)" + ) + else: + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, tensor) + if output is not None: + tensor = output + self._buffers[name] = tensor + if persistent: + self._non_persistent_buffers_set.discard(name) + else: + self._non_persistent_buffers_set.add(name) + + def register_parameter(self, name: str, param: Parameter | None) -> None: + r"""Add a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (str): name of the parameter. The parameter can be accessed + from this module using the given name + param (Parameter or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if "_parameters" not in self.__dict__: + raise AttributeError( + "cannot assign parameter before Module.__init__() call" + ) + + elif not isinstance(name, str): + raise TypeError( + f"parameter name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('parameter name can\'t contain "."') + elif name == "": + raise KeyError('parameter name can\'t be empty string ""') + elif hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + elif not isinstance(param, Parameter): + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + "(torch.nn.Parameter or None required)" + ) + elif param.grad_fn: + raise ValueError( + f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " + f"parameters must be created explicitly. To express '{name}' " + "as a function of another Tensor, compute the value in " + "the forward() method." + ) + else: + for hook in _global_parameter_registration_hooks.values(): + output = hook(self, name, param) + if output is not None: + param = output + self._parameters[name] = param + + def add_module(self, name: str, module: Optional["Module"]) -> None: + r"""Add a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (str): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. + """ + if not isinstance(module, Module) and module is not None: + raise TypeError(f"{torch.typename(module)} is not a Module subclass") + elif not isinstance(name, str): + raise TypeError( + f"module name should be a string. Got {torch.typename(name)}" + ) + elif hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + elif "." in name: + raise KeyError(f'module name can\'t contain ".", got: {name}') + elif name == "": + raise KeyError('module name can\'t be empty string ""') + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, module) + if output is not None: + module = output + self._modules[name] = module + + def register_module(self, name: str, module: Optional["Module"]) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def get_submodule(self, target: str) -> "Module": + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) + ) + (linear): Linear(in_features=100, out_features=200, bias=True) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` which has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To check whether or not we have the ``linear`` submodule, we + would call ``get_submodule("net_b.linear")``. To check whether + we have the ``conv`` submodule, we would call + ``get_submodule("net_b.net_c.conv")``. + + The runtime of ``get_submodule`` is bounded by the degree + of module nesting in ``target``. A query against + ``named_modules`` achieves the same result, but it is O(N) in + the number of transitive modules. So, for a simple check to see + if some submodule exists, ``get_submodule`` should always be + used. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Module: The submodule referenced by ``target`` + + Raises: + AttributeError: If at any point along the path resulting from + the target string the (sub)path resolves to a non-existent + attribute name or an object that is not an instance of ``nn.Module``. + """ + if target == "": + return self + + atoms: list[str] = target.split(".") + mod: torch.nn.Module = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no attribute `" + item + "`" + ) + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not an nn.Module") + + return mod + + def set_submodule( + self, target: str, module: "Module", strict: bool = False + ) -> None: + """ + Set the submodule given by ``target`` if it exists, otherwise throw an error. + + .. note:: + If ``strict`` is set to ``False`` (default), the method will replace an existing submodule + or create a new submodule if the parent module exists. If ``strict`` is set to ``True``, + the method will only attempt to replace an existing submodule and throw an error if + the submodule does not exist. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(3, 3, 3) + ) + (linear): Linear(3, 3) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To override the ``Conv2d`` with a new submodule ``Linear``, you + could call ``set_submodule("net_b.net_c.conv", nn.Linear(1, 1))`` + where ``strict`` could be ``True`` or ``False`` + + To add a new submodule ``Conv2d`` to the existing ``net_b`` module, + you would call ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))``. + + In the above if you set ``strict=True`` and call + ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)``, an AttributeError + will be raised because ``net_b`` does not have a submodule named ``conv``. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + module: The module to set the submodule to. + strict: If ``False``, the method will replace an existing submodule + or create a new submodule if the parent module exists. If ``True``, + the method will only attempt to replace an existing submodule and throw an error + if the submodule doesn't already exist. + + Raises: + ValueError: If the ``target`` string is empty or if ``module`` is not an instance of ``nn.Module``. + AttributeError: If at any point along the path resulting from + the ``target`` string the (sub)path resolves to a non-existent + attribute name or an object that is not an instance of ``nn.Module``. + """ + if target == "": + raise ValueError("Cannot set the submodule without a target name!") + + atoms: list[str] = target.split(".") + if not isinstance(module, torch.nn.Module): + raise ValueError( + "`" + "module" + f"` is not an nn.Module, found {type(module)}" + ) + if len(atoms) == 1: + parent: torch.nn.Module = self + else: + parent_key = ".".join(atoms[:-1]) + parent = self.get_submodule(parent_key) + + if strict and not hasattr(parent, atoms[-1]): + raise AttributeError( + parent._get_name() + " has no attribute `" + atoms[-1] + "`" + ) + if hasattr(parent, atoms[-1]): + mod = getattr(parent, atoms[-1]) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + atoms[-1] + "` is not an nn.Module") + setattr(parent, atoms[-1], module) + + def get_parameter(self, target: str) -> "Parameter": + """Return the parameter given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + param_name + "`" + ) + + param: torch.nn.Parameter = getattr(mod, param_name) + + if not isinstance(param, torch.nn.Parameter): + raise AttributeError("`" + param_name + "` is not an nn.Parameter") + + return param + + def get_buffer(self, target: str) -> "Tensor": + """Return the buffer given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the buffer + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.Tensor: The buffer referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not a + buffer + """ + module_path, _, buffer_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, buffer_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + buffer_name + "`" + ) + + buffer: torch.Tensor = getattr(mod, buffer_name) + + if buffer_name not in mod._buffers: + raise AttributeError("`" + buffer_name + "` is not a buffer") + + return buffer + + def get_extra_state(self) -> Any: + """Return any extra state to include in the module's state_dict. + + Implement this and a corresponding :func:`set_extra_state` for your module + if you need to store extra state. This function is called when building the + module's `state_dict()`. + + Note that extra state should be picklable to ensure working serialization + of the state_dict. We only provide backwards compatibility guarantees + for serializing Tensors; other objects may break backwards compatibility if + their serialized pickled form changes. + + Returns: + object: Any extra state to store in the module's state_dict + """ + raise RuntimeError( + "Reached a code path in Module.get_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug." + ) + + def set_extra_state(self, state: Any) -> None: + """Set extra state contained in the loaded `state_dict`. + + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug." + ) + + def _apply(self, fn, recurse=True): + if recurse: + for module in self.children(): + module._apply(fn) + + from torch._subclasses.fake_tensor import FakeTensor + + def compute_should_use_set_data(tensor, tensor_applied) -> bool: + if torch._has_compatible_shallow_copy_type( + tensor, tensor_applied + ) and not isinstance(tensor_applied, FakeTensor): + # If the new tensor has compatible tensor type as the existing tensor, + # the current behavior is to change the tensor in-place using `.data =`, + # and the future behavior is to overwrite the existing tensor. However, + # changing the current behavior is a BC-breaking change, and we want it + # to happen in future releases. So for now we introduce the + # `torch.__future__.get_overwrite_module_params_on_conversion()` + # global flag to let the user control whether they want the future + # behavior of overwriting the existing tensor or not. + return not torch.__future__.get_overwrite_module_params_on_conversion() + else: + return False + + should_use_swap_tensors = ( + torch.__future__.get_swap_module_params_on_conversion() + ) + + for key, param in self._parameters.items(): + if param is None: + continue + # Tensors stored in modules are graph leaves, and we don't want to + # track autograd history of `param_applied`, so we have to use + # `with torch.no_grad():` + with torch.no_grad(): + param_applied = fn(param) + p_should_use_set_data = compute_should_use_set_data(param, param_applied) + + # subclasses may have multiple child tensors so we need to use swap_tensors + p_should_use_swap_tensors = ( + should_use_swap_tensors + or is_traceable_wrapper_subclass(param_applied) + or isinstance(param, FakeTensor) + ) + + param_grad = param.grad + if p_should_use_swap_tensors: + try: + if param_grad is not None: + # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. + # Decrement use count of the gradient by setting to None + param.grad = None + param_applied = torch.nn.Parameter( + # pyrefly: ignore [bad-argument-type] + param_applied, + requires_grad=param.requires_grad, + ) + torch.utils.swap_tensors(param, param_applied) + except Exception as e: + if param_grad is not None: + param.grad = param_grad + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}" + ) from e + out_param = param + elif p_should_use_set_data: + # pyrefly: ignore [bad-assignment] + param.data = param_applied + out_param = param + else: + assert isinstance(param, Parameter) + assert param.is_leaf + # pyrefly: ignore [bad-argument-type] + out_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = out_param + + if param_grad is not None: + with torch.no_grad(): + grad_applied = fn(param_grad) + g_should_use_set_data = compute_should_use_set_data( + param_grad, grad_applied + ) + if p_should_use_swap_tensors: + grad_applied.requires_grad_(param_grad.requires_grad) + try: + torch.utils.swap_tensors(param_grad, grad_applied) + except Exception as e: + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" + ) from e + out_param.grad = param_grad + elif g_should_use_set_data: + assert out_param.grad is not None + out_param.grad.data = grad_applied + else: + assert param_grad.is_leaf + out_param.grad = grad_applied.requires_grad_( + param_grad.requires_grad + ) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + def apply(self, fn: Callable[["Module"], None]) -> Self: + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model + (see also :ref:`nn-init-doc`). + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + + Example:: + + >>> @torch.no_grad() + >>> def init_weights(m): + >>> print(m) + >>> if type(m) is nn.Linear: + >>> m.weight.fill_(1.0) + >>> print(m.weight) + >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + >>> net.apply(init_weights) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + + """ + for module in self.children(): + module.apply(fn) + fn(self) + return self + + def cuda(self, device: int | device | None = None) -> Self: + r"""Move all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on GPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Args: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.cuda(device)) + + def ipu(self, device: int | device | None = None) -> Self: + r"""Move all model parameters and buffers to the IPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on IPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.ipu(device)) + + def xpu(self, device: int | device | None = None) -> Self: + r"""Move all model parameters and buffers to the XPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on XPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.xpu(device)) + + def mtia(self, device: int | device | None = None) -> Self: + r"""Move all model parameters and buffers to the MTIA. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing the optimizer if the module will + live on MTIA while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.mtia(device)) + + def cpu(self) -> Self: + r"""Move all model parameters and buffers to the CPU. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.cpu()) + + def type(self, dst_type: dtype | str) -> Self: + r"""Casts all parameters and buffers to :attr:`dst_type`. + + .. note:: + This method modifies the module in-place. + + Args: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + return self._apply(lambda t: t.type(dst_type)) + + def float(self) -> Self: + r"""Casts all floating point parameters and buffers to ``float`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.float() if t.is_floating_point() else t) + + def double(self) -> Self: + r"""Casts all floating point parameters and buffers to ``double`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.double() if t.is_floating_point() else t) + + def half(self) -> Self: + r"""Casts all floating point parameters and buffers to ``half`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.half() if t.is_floating_point() else t) + + def bfloat16(self) -> Self: + r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) + + def to_empty(self, *, device: DeviceLikeType | None, recurse: bool = True) -> Self: + r"""Move the parameters and buffers to the specified device without copying storage. + + Args: + device (:class:`torch.device`): The desired device of the parameters + and buffers in this module. + recurse (bool): Whether parameters and buffers of submodules should + be recursively moved to the specified device. + + Returns: + Module: self + """ + return self._apply( + lambda t: torch.empty_like(t, device=device), recurse=recurse + ) + + @overload + def to( + self, + device: DeviceLikeType | None = ..., + dtype: dtype | None = ..., + non_blocking: bool = ..., + ) -> Self: ... + + @overload + def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: ... + + @overload + def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: ... + + def to(self, *args, **kwargs): + r"""Move and/or cast the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + :noindex: + + .. function:: to(dtype, non_blocking=False) + :noindex: + + .. function:: to(tensor, non_blocking=False) + :noindex: + + .. function:: to(memory_format=torch.channels_last) + :noindex: + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point or complex :attr:`dtype`\ s. In addition, this method will + only cast the floating point or complex parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + .. note:: + This method modifies the module in-place. + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> linear = nn.Linear(2, 2) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]]) + >>> linear.to(torch.double) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]], dtype=torch.float64) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) + >>> gpu1 = torch.device("cuda:1") + >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') + >>> cpu = torch.device("cpu") + >>> linear.to(cpu) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16) + + >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) + >>> linear.weight + Parameter containing: + tensor([[ 0.3741+0.j, 0.2382+0.j], + [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) + >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) + tensor([[0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) + + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + # pyrefly: ignore [not-iterable] + *args, + **kwargs, + ) + + if dtype is not None: + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError( + "nn.Module.to only accepts floating point or complex " + f"dtypes, but got desired dtype={dtype}" + ) + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "if a complex module does not work as expected.", + stacklevel=2, + ) + + def convert(t): + try: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + ) + except NotImplementedError as e: + if str(e) == "Cannot copy out of meta tensor; no data!": + raise NotImplementedError( + f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " + f"when moving module from meta to a different device." + ) from None + else: + raise + + return self._apply(convert) + + def register_full_backward_pre_hook( + self, + hook: Callable[["Module", _grad_t], None | _grad_t], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward pre-hook on the module. + + The hook will be called every time the gradients for the module are computed. + The hook should have the following signature:: + + hook(module, grad_output) -> tuple[Tensor, ...], Tensor or None + + The :attr:`grad_output` is a tuple. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the output that will be used in place of :attr:`grad_output` in + subsequent computations. Entries in :attr:`grad_output` will be ``None`` for + all non-Tensor arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward_pre`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward_pre`` hooks + on this :class:`torch.nn.Module`. Note that global + ``backward_pre`` hooks registered with + :func:`register_module_full_backward_pre_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = RemovableHandle(self._backward_pre_hooks) + self._backward_pre_hooks[handle.id] = hook + if prepend: + self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_backward_hook( + self, hook: Callable[["Module", _grad_t, _grad_t], None | _grad_t] + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and + the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is True: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) + + self._is_full_backward_hook = False + + handle = RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_full_backward_hook( + self, + hook: Callable[["Module", _grad_t, _grad_t], None | _grad_t], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + The hook will be called every time the gradients with respect to a module are computed, and its firing rules are as follows: + + 1. Ordinarily, the hook fires when the gradients are computed with respect to the module inputs. + 2. If none of the module inputs require gradients, the hook will fire when the gradients are computed + with respect to module outputs. + 3. If none of the module outputs require gradients, then the hooks will not fire. + + The hook should have the following signature:: + + hook(module, grad_input, grad_output) -> tuple(Tensor) or None + + The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients + with respect to the inputs and outputs respectively. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the input that will be used in place of :attr:`grad_input` in + subsequent computations. :attr:`grad_input` will only correspond to the inputs given + as positional arguments and all kwarg arguments are ignored. Entries + in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor + arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs or outputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward`` hooks on + this :class:`torch.nn.Module`. Note that global + ``backward`` hooks registered with + :func:`register_module_full_backward_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is False: + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) + + self._is_full_backward_hook = True + + handle = RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + if prepend: + self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _get_backward_hooks(self): + r"""Return the backward hooks for use in the call function. + + It returns two lists, one with the full backward hooks and one with the non-full + backward hooks. + """ + full_backward_hooks: list[Callable] = [] + if _global_is_full_backward_hook is True: + full_backward_hooks += _global_backward_hooks.values() + if self._is_full_backward_hook is True: + full_backward_hooks += self._backward_hooks.values() + + non_full_backward_hooks: list[Callable] = [] + if _global_is_full_backward_hook is False: + non_full_backward_hooks += _global_backward_hooks.values() + if self._is_full_backward_hook is False: + non_full_backward_hooks += self._backward_hooks.values() + + return full_backward_hooks, non_full_backward_hooks + + def _get_backward_pre_hooks(self): + backward_pre_hooks: list[Callable] = [] + backward_pre_hooks += _global_backward_pre_hooks.values() + backward_pre_hooks += self._backward_pre_hooks.values() + + return backward_pre_hooks + + def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn) -> None: + if not isinstance(result, torch.Tensor): + if not ( + isinstance(result, tuple) + and all(isinstance(r, torch.Tensor) for r in result) + ): + warnings.warn( + "Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + return + else: + result = (result,) + + if not isinstance(inputs, torch.Tensor): + if not ( + isinstance(inputs, tuple) + and all(isinstance(i, torch.Tensor) for i in inputs) + ): + warnings.warn( + "Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + return + else: + inputs = (inputs,) + + # At this point we are sure that inputs and result are tuple of Tensors + out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} + if len(out_grad_fn) == 0 or ( + len(out_grad_fn) == 1 and grad_fn not in out_grad_fn + ): + warnings.warn( + "Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.", + FutureWarning, + stacklevel=2, + ) + elif len(out_grad_fn) > 1: + warnings.warn( + "Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) + else: + # At this point the grad_output part of the hook will most likely be correct + inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} + + next_functions = {n[0] for n in grad_fn.next_functions} + + if inputs_grad_fn != next_functions: + warnings.warn( + "Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.", + FutureWarning, + stacklevel=2, + ) + + def register_forward_pre_hook( + self, + hook: Callable[[T, tuple[Any, ...]], Any | None] + | Callable[ + [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + r"""Register a forward pre-hook on the module. + + The hook will be called every time before :func:`forward` is invoked. + + + If ``with_kwargs`` is false or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + input. User can either return a tuple or a single modified value in the + hook. We will wrap the value into a tuple if a single value is returned + (unless that value is already a tuple). The hook should have the + following signature:: + + hook(module, args) -> None or modified input + + If ``with_kwargs`` is true, the forward pre-hook will be passed the + kwargs given to the forward function. And if the hook modifies the + input, both the args and kwargs should be returned. The hook should have + the following signature:: + + hook(module, args, kwargs) -> None or a tuple of modified input and kwargs + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``forward_pre`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward_pre`` hooks + on this :class:`torch.nn.Module`. Note that global + ``forward_pre`` hooks registered with + :func:`register_module_forward_pre_hook` will fire before all + hooks registered by this method. + Default: ``False`` + with_kwargs (bool): If true, the ``hook`` will be passed the kwargs + given to the forward function. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle( + self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs + ) + self._forward_pre_hooks[handle.id] = hook + if with_kwargs: + self._forward_pre_hooks_with_kwargs[handle.id] = True + + if prepend: + self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_forward_hook( + self, + hook: Callable[[T, tuple[Any, ...], Any], Any | None] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], + *, + prepend: bool = False, + with_kwargs: bool = False, + always_call: bool = False, + ) -> RemovableHandle: + r"""Register a forward hook on the module. + + The hook will be called every time after :func:`forward` has computed an output. + + If ``with_kwargs`` is ``False`` or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + output. It can modify the input inplace but it will not have effect on + forward since this is called after :func:`forward` is called. The hook + should have the following signature:: + + hook(module, args, output) -> None or modified output + + If ``with_kwargs`` is ``True``, the forward hook will be passed the + ``kwargs`` given to the forward function and be expected to return the + output possibly modified. The hook should have the following signature:: + + hook(module, args, kwargs, output) -> None or modified output + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If ``True``, the provided ``hook`` will be fired + before all existing ``forward`` hooks on this + :class:`torch.nn.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward`` hooks on + this :class:`torch.nn.Module`. Note that global + ``forward`` hooks registered with + :func:`register_module_forward_hook` will fire before all hooks + registered by this method. + Default: ``False`` + with_kwargs (bool): If ``True``, the ``hook`` will be passed the + kwargs given to the forward function. + Default: ``False`` + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle( + self._forward_hooks, + extra_dict=[ + self._forward_hooks_with_kwargs, + self._forward_hooks_always_called, + ], + ) + self._forward_hooks[handle.id] = hook + if with_kwargs: + self._forward_hooks_with_kwargs[handle.id] = True + if always_call: + self._forward_hooks_always_called[handle.id] = True + if prepend: + self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _slow_forward(self, *input, **kwargs): + tracing_state = torch._C._get_tracing_state() + if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): + return self.forward(*input, **kwargs) + recording_scopes = torch.jit._trace._trace_module_map is not None + if recording_scopes: + # type ignore was added because at this point one knows that + # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] + name = torch.jit._trace._trace_module_map.get(self, None) # type: ignore[operator, union-attr] + if name: + tracing_state.push_scope(name) + else: + recording_scopes = False + try: + result = self.forward(*input, **kwargs) + finally: + if recording_scopes: + tracing_state.pop_scope() + return result + + def _wrapped_call_impl(self, *args, **kwargs): + if self._compiled_call_impl is not None: + return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] + else: + return self._call_impl(*args, **kwargs) + + # torchrec tests the code consistency with the following code + # fmt: off + def _call_impl(self, *args, **kwargs): + forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + # If we don't have any hooks, we want to skip the rest of the logic in + # this function, and just call forward. + if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks + or _global_backward_pre_hooks or _global_backward_hooks + or _global_forward_hooks or _global_forward_pre_hooks): + return forward_call(*args, **kwargs) + + result = None + called_always_called_hooks = set() + + def inner(): + nonlocal result, args, kwargs + + full_backward_hooks, non_full_backward_hooks = [], [] + backward_pre_hooks = [] + if self._backward_pre_hooks or _global_backward_pre_hooks: + backward_pre_hooks = self._get_backward_pre_hooks() + + if self._backward_hooks or _global_backward_hooks: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + + if _global_forward_pre_hooks or self._forward_pre_hooks: + for hook_id, hook in ( + *_global_forward_pre_hooks.items(), + *self._forward_pre_hooks.items(), + ): + if hook_id in self._forward_pre_hooks_with_kwargs: + args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] + if args_kwargs_result is not None: + if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: + args, kwargs = args_kwargs_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kwargs_result}." + ) + else: + args_result = hook(self, args) + if args_result is not None: + if not isinstance(args_result, tuple): + args_result = (args_result,) + args = args_result + + bw_hook = None + if full_backward_hooks or backward_pre_hooks: + bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) + args = bw_hook.setup_input_hook(args) + + result = forward_call(*args, **kwargs) + if _global_forward_hooks or self._forward_hooks: + for hook_id, hook in ( + *_global_forward_hooks.items(), + *self._forward_hooks.items(), + ): + # mark that always called hook is run + if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: + called_always_called_hooks.add(hook_id) + + if hook_id in self._forward_hooks_with_kwargs or hook_id in _global_forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) + else: + hook_result = hook(self, args, result) + + if hook_result is not None: + result = hook_result + + if bw_hook: + if not isinstance(result, (torch.Tensor, tuple)): + warnings.warn("For backward hooks to be called," + " module output should be a Tensor or a tuple of Tensors" + f" but received {type(result)}", stacklevel=2) + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if non_full_backward_hooks: + var = result + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next(v for v in var.values() if isinstance(v, torch.Tensor)) + else: + var = var[0] + grad_fn = var.grad_fn + if grad_fn is not None: + for hook in non_full_backward_hooks: + grad_fn.register_hook(_WrappedHook(hook, self)) + self._maybe_warn_non_full_backward_hook(args, result, grad_fn) + + return result + + # This is technically not behavior equivalent when compiling, but it's + # incredibly unlikely we will ever support throwing an exception in NN + # module, and then catching it here, and then reraising it, and then + # catching it again, and expecting the resulting frame to be compiled. + # The reraise here just gunks up our exception handling for no good + # reason. Don't try to run the always called hooks in event of + # exception. + if torch.compiler.is_compiling(): + return inner() + + try: + return inner() + except Exception: + # run always called hooks if they have not already been run + # For now only forward hooks have the always_call option but perhaps + # this functionality should be added to full backward hooks as well. + for hook_id, hook in _global_forward_hooks.items(): + if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("global module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}", stacklevel=2) + continue + + for hook_id, hook in self._forward_hooks.items(): + if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] + else: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}", stacklevel=2) + continue + # raise exception raised in try block + raise + # fmt: on + + __call__: Callable[..., Any] = _wrapped_call_impl + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + # Support loading old checkpoints that don't have the following attrs: + if "_forward_pre_hooks" not in self.__dict__: + self._forward_pre_hooks = OrderedDict() + if "_forward_pre_hooks_with_kwargs" not in self.__dict__: + self._forward_pre_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_with_kwargs" not in self.__dict__: + self._forward_hooks_with_kwargs = OrderedDict() + if "_forward_hooks_always_called" not in self.__dict__: + self._forward_hooks_always_called = OrderedDict() + if "_state_dict_hooks" not in self.__dict__: + self._state_dict_hooks = OrderedDict() + if "_state_dict_pre_hooks" not in self.__dict__: + self._state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_pre_hooks" not in self.__dict__: + self._load_state_dict_pre_hooks = OrderedDict() + if "_load_state_dict_post_hooks" not in self.__dict__: + self._load_state_dict_post_hooks = OrderedDict() + if "_non_persistent_buffers_set" not in self.__dict__: + self._non_persistent_buffers_set = set() + if "_is_full_backward_hook" not in self.__dict__: + self._is_full_backward_hook = None + if "_backward_pre_hooks" not in self.__dict__: + self._backward_pre_hooks = OrderedDict() + + # It is crucial that the return type is not annotated as `Any`, otherwise type checking + # on `torch.nn.Module` and all its subclasses is largely disabled as a result. See: + # https://github.com/pytorch/pytorch/pull/115074 + def __getattr__(self, name: str) -> Union[Tensor, "Module"]: + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + return _parameters[name] + if "_buffers" in self.__dict__: + _buffers = self.__dict__["_buffers"] + if name in _buffers: + return _buffers[name] + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] + if name in modules: + return modules[name] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: + def remove_from(*dicts_or_sets) -> None: + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get("_parameters") + if isinstance(value, Parameter): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + ) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + "(torch.nn.Parameter or None expected)" + ) + self.register_parameter(name, value) + else: + modules = self.__dict__.get("_modules") + if isinstance(value, Module): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + "(torch.nn.Module or None expected)" + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + else: + buffers = self.__dict__.get("_buffers") + if isinstance(value, Buffer) or buffers is not None and name in buffers: + if value is not None and not ( + isinstance(value, torch.Tensor) + or hasattr(value, "__torch_function__") + ): + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + "(torch.nn.Buffer, torch.Tensor or None expected)" + ) + if isinstance(value, Buffer): + persistent = value.persistent + else: + persistent = name not in self._non_persistent_buffers_set + # === HACK === + # This whole block below should just be: + # self.register_buffer(name, value, persistent) + + # But to support subclasses of nn.Module that (wrongfully) implement a + # register_buffer() method that doesn't have the "persistent" + # argument. Only pass it in if it is accepted otherwise assume + # it is always true + if ( + getattr(self.register_buffer, "__func__", None) + is torch.nn.Module.register_buffer + ): + self.register_buffer(name, value, persistent) + else: + sign = inspect.signature(self.register_buffer) + if "persistent" in sign.parameters: + self.register_buffer(name, value, persistent) + else: + if not persistent: + raise RuntimeError( + "Registering a non-persistent buffer " + "on a Module subclass that implements " + "register_buffer() without the persistent " + "argument is not allowed." + ) + # Assume that the implementation without the argument has the + # behavior from before the argument was added: persistent=True + self.register_buffer(name, value) + # === HACK END === + else: + super().__setattr__(name, value) + + def __delattr__(self, name) -> None: + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + else: + super().__delattr__(name) + + def _register_state_dict_hook(self, hook): + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None or state_dict + + The registered hooks can modify the ``state_dict`` inplace or return a new one. + If a new ``state_dict`` is returned, it will only be respected if it is the root + module that :meth:`~nn.Module.state_dict` is called from. + """ + if getattr(hook, "_from_public_api", False): + raise RuntimeError( + "Cannot register the same function as the state dict post hook that was " + "previously registered via register_state_dict_post_hook" + ) + handle = RemovableHandle(self._state_dict_hooks) + self._state_dict_hooks[handle.id] = hook + return handle + + def register_state_dict_post_hook(self, hook): + r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata) -> None + + The registered hooks can modify the ``state_dict`` inplace. + """ + # In _register_state_dict_hook there was a bug described in + # https://github.com/pytorch/pytorch/issues/117437 where the return value + # was only respected for the root module but not child submodules. + # We fix this in this public version by only allowing inplace modifications on + # the state_dict by the hook. However, since hooks registered via both these + # APIs will be added to `_state_dict_hooks` and the type of `_state_dict_hooks` + # cannot be changed due to many dependencies on it, we mark a hook + # as being registered via the public API by setting `_from_public_api` on it. + # In the implementation of `state_dict`, if the callable does not have this + # flag, the old behavior of respecting the return value will be preserved + # for the root module, otherwise, we ensure that the hook returns None. + hook._from_public_api = True + handle = RemovableHandle(self._state_dict_hooks) + self._state_dict_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook(self, hook): + r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. + + It should have the following signature:: + hook(module, prefix, keep_vars) -> None + + The registered hooks can be used to perform pre-processing before the ``state_dict`` + call is made. + """ + handle = RemovableHandle(self._state_dict_pre_hooks) + self._state_dict_pre_hooks[handle.id] = hook + return handle + + def _save_to_state_dict(self, destination, prefix, keep_vars) -> None: + r"""Save module state to the `destination` dictionary. + + The `destination` dictionary will contain the state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(self.__class__, "get_extra_state", Module.get_extra_state) + is not Module.get_extra_state + ): + destination[extra_state_key] = self.get_extra_state() + + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. + T_destination = TypeVar("T_destination", bound=dict[str, Any]) + + @overload + def state_dict( + self, + *, + destination: T_destination, + prefix: str = ..., + keep_vars: bool = ..., + ) -> T_destination: ... + + @overload + def state_dict( + self, + *, + prefix: str = ..., + keep_vars: bool = ..., + ) -> dict[str, Any]: ... + + # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. + # Also remove the logic for arg parsing together. + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + r"""Return a dictionary containing references to the whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. note:: + The returned object is a shallow copy. It contains references + to the module's parameters and buffers. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.", + FutureWarning, + stacklevel=2, + ) + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == "": + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + + if destination is None: + destination = OrderedDict() + # pyrefly: ignore [missing-attribute] + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + for hook in self._state_dict_pre_hooks.values(): + hook(self, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict( + destination=destination, + prefix=prefix + name + ".", + keep_vars=keep_vars, + ) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if not getattr(hook, "_from_public_api", False): + if hook_result is not None: + destination = hook_result + else: + if hook_result is not None: + raise RuntimeError("state_dict post-hook must return None") + return destination + + def _register_load_state_dict_pre_hook(self, hook, with_module=False): + r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details. + + A subtle difference is that if ``with_module`` is set to ``False``, then the + hook will not take the ``module`` as the first argument whereas + :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the + ``module`` as the first argument. + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + with_module (bool, optional): Whether or not to pass the module + instance to the hook as the first parameter. + """ + handle = RemovableHandle(self._load_state_dict_pre_hooks) + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( + hook, self if with_module else None + ) + return handle + + def register_load_state_dict_pre_hook(self, hook): + r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called. + + It should have the following signature:: + hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950 + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + """ + return self._register_load_state_dict_pre_hook(hook, with_module=True) + + def register_load_state_dict_post_hook(self, hook): + r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called. + + It should have the following signature:: + hook(module, incompatible_keys) -> None + + The ``module`` argument is the current module that this hook is registered + on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting + of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` + is a ``list`` of ``str`` containing the missing keys and + ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. + + The given incompatible_keys can be modified inplace if needed. + + Note that the checks performed when calling :func:`load_state_dict` with + ``strict=True`` are affected by modifications the hook makes to + ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either + set of keys will result in an error being thrown when ``strict=True``, and + clearing out both missing and unexpected keys will avoid an error. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = RemovableHandle(self._load_state_dict_post_hooks) + self._load_state_dict_post_hooks[handle.id] = hook + return handle + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. + + This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + Additionally, :attr:`local_metadata` can also contain the key + `assign_to_params_buffers` that indicates whether keys should be + assigned their corresponding tensor in the state_dict. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + persistent_buffers = { + k: v + for k, v in self._buffers.items() + if k not in self._non_persistent_buffers_set + } + local_name_params = itertools.chain( + self._parameters.items(), + # pyrefly: ignore [bad-argument-type] + persistent_buffers.items(), + ) + local_state = {k: v for k, v in local_name_params if v is not None} + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append( + f'While copying the parameter named "{key}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + f"received {type(input_param)}" + ) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + and input_param.shape[0] == 1 + ): + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " + f"the shape in current model is {param.shape}." + ) + continue + + if ( + param.is_meta + and not input_param.is_meta + and not assign_to_params_buffers + ): + warnings.warn( + f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " + "parameter in the current model, which is a no-op. (Did you mean to " + "pass `assign=True` to assign items in the state dictionary to their " + "corresponding key in the module instead of copying them in place?)", + stacklevel=2, + ) + + try: + with torch.no_grad(): + if use_swap_tensors: + new_input_param = param.module_load( + input_param, assign=assign_to_params_buffers + ) + if id(new_input_param) == id(input_param) or id( + new_input_param + ) == id(param): + raise RuntimeError( + "module_load returned one of self or other, please .detach() " + "the result if returning one of the inputs in module_load" + ) + if isinstance(param, torch.nn.Parameter): + if not isinstance(new_input_param, torch.nn.Parameter): + new_input_param = torch.nn.Parameter( + new_input_param, + requires_grad=param.requires_grad, + ) + else: + new_input_param.requires_grad_(param.requires_grad) + torch.utils.swap_tensors(param, new_input_param) + del new_input_param + elif assign_to_params_buffers: + # Shape checks are already done above + if isinstance(param, torch.nn.Parameter): + if not isinstance(input_param, torch.nn.Parameter): + input_param = torch.nn.Parameter( + input_param, requires_grad=param.requires_grad + ) + else: + input_param.requires_grad_(param.requires_grad) + setattr(self, name, input_param) + else: + param.copy_(input_param) + except Exception as ex: + action = "swapping" if use_swap_tensors else "copying" + error_msgs.append( + f'While {action} the parameter named "{key}", ' + f"whose dimensions in the model are {param.size()} and " + f"whose dimensions in the checkpoint are {input_param.size()}, " + f"an exception occurred : {ex.args}." + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(self.__class__, "set_extra_state", Module.set_extra_state) + is not Module.set_extra_state + ): + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict: + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :].split(".", 1) + # Must be Module if it have attributes + if len(input_name) > 1: + if input_name[0] not in self._modules: + unexpected_keys.append(key) + elif input_name[0] not in local_state: + unexpected_keys.append(key) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. + + If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + .. warning:: + If :attr:`assign` is ``True`` the optimizer must be created after + the call to :attr:`load_state_dict` unless + :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): When set to ``False``, the properties of the tensors + in the current module are preserved whereas setting it to ``True`` preserves + properties of the Tensors in the state dict. The only + exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` + for which the value from the module is preserved. Default: ``False`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * ``missing_keys`` is a list of str containing any keys that are expected + by this module but missing from the provided ``state_dict``. + * ``unexpected_keys`` is a list of str containing the keys that are not + expected by this module but present in the provided ``state_dict``. + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError( + f"Expected state_dict to be dict-like, got {type(state_dict)}." + ) + + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix="") -> None: + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + "." + child_state_dict = { + k: v + for k, v in local_state_dict.items() + if k.startswith(child_prefix) + } + load(child, child_state_dict, child_prefix) # noqa: F821 + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ), + ) + if len(missing_keys) > 0: + error_msgs.insert( + 0, + "Missing key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in missing_keys) + ), + ) + + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _named_members( + self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True + ): + r"""Help yield various names + members of modules.""" + memo = set() + modules = ( + self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) + if recurse + else [(prefix, self)] + ) + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ("." if module_prefix else "") + k + yield name, v + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + r"""Return an iterator over module parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + Parameter: module parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for param in model.parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _name, param in self.named_parameters(recurse=recurse): + yield param + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Parameter]]: + r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + remove_duplicate (bool, optional): whether to remove the duplicated + parameters in the result. Defaults to True. + + Yields: + (str, Parameter): Tuple containing the name and parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, param in self.named_parameters(): + >>> if name in ['bias']: + >>> print(param.size()) + + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) + yield from gen + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + r"""Return an iterator over module buffers. + + Args: + recurse (bool): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. + + Yields: + torch.Tensor: module buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for buf in model.buffers(): + >>> print(type(buf), buf.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, buf in self.named_buffers(recurse=recurse): + yield buf + + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Tensor]]: + r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. + + Args: + prefix (str): prefix to prepend to all buffer names. + recurse (bool, optional): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. Defaults to True. + remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. + + Yields: + (str, torch.Tensor): Tuple containing the name and buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, buf in self.named_buffers(): + >>> if name in ['running_var']: + >>> print(buf.size()) + + """ + gen = self._named_members( + lambda module: module._buffers.items(), + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) + yield from gen + + def children(self) -> Iterator["Module"]: + r"""Return an iterator over immediate children modules. + + Yields: + Module: a child module + """ + for _name, module in self.named_children(): + yield module + + def named_children(self) -> Iterator[tuple[str, "Module"]]: + r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. + + Yields: + (str, Module): Tuple containing a name and child module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, module in model.named_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(module) + + """ + memo = set() + for name, module in self._modules.items(): + if module is not None and module not in memo: + memo.add(module) + yield name, module + + def modules(self) -> Iterator["Module"]: + r"""Return an iterator over all modules in the network. + + Yields: + Module: a module in the network + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.modules()): + ... print(idx, '->', m) + + 0 -> Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + 1 -> Linear(in_features=2, out_features=2, bias=True) + + """ + for _, module in self.named_modules(): + yield module + + def named_modules( + self, + memo: set["Module"] | None = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (str, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.named_modules()): + ... print(idx, '->', m) + + 0 -> ('', Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + )) + 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) + + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + yield from module.named_modules( + memo, submodule_prefix, remove_duplicate + ) + + def train(self, mode: bool = True) -> Self: + r"""Set the module in training mode. + + This has an effect only on certain modules. See the documentation of + particular modules for details of their behaviors in training/evaluation + mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + Args: + mode (bool): whether to set training mode (``True``) or evaluation + mode (``False``). Default: ``True``. + + Returns: + Module: self + """ + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + self.training = mode + for module in self.children(): + module.train(mode) + return self + + def eval(self) -> Self: + r"""Set the module in evaluation mode. + + This has an effect only on certain modules. See the documentation of + particular modules for details of their behaviors in training/evaluation + mode, i.e. whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + This is equivalent with :meth:`self.train(False) `. + + See :ref:`locally-disable-grad-doc` for a comparison between + `.eval()` and several similar mechanisms that may be confused with it. + + Returns: + Module: self + """ + return self.train(False) + + def requires_grad_(self, requires_grad: bool = True) -> Self: + r"""Change if autograd should record operations on parameters in this module. + + This method sets the parameters' :attr:`requires_grad` attributes + in-place. + + This method is helpful for freezing part of the module for finetuning + or training parts of a model individually (e.g., GAN training). + + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + + Args: + requires_grad (bool): whether autograd should record operations on + parameters in this module. Default: ``True``. + + Returns: + Module: self + """ + for p in self.parameters(): + p.requires_grad_(requires_grad) + return self + + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset gradients of all model parameters. + + See similar function under :class:`torch.optim.Optimizer` for more context. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + See :meth:`torch.optim.Optimizer.zero_grad` for details. + """ + if getattr(self, "_is_replica", False): + warnings.warn( + "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " + "The parameters are copied (in a differentiable manner) from the original module. " + "This means they are not leaf nodes in autograd and so don't accumulate gradients. " + "If you need gradients in your forward method, consider using autograd.grad instead.", + stacklevel=2, + ) + + for p in self.parameters(): + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() + + def share_memory(self) -> Self: + r"""See :meth:`torch.Tensor.share_memory_`.""" + return self._apply(lambda t: t.share_memory_()) + + def _get_name(self): + return self.__class__.__name__ + + def extra_repr(self) -> str: + r"""Return the extra representation of the module. + + To print customized extra information, you should re-implement + this method in your own modules. Both single-line and multi-line + strings are acceptable. + """ + return "" + + def __repr__(self) -> str: + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + def __dir__(self): + module_attrs = dir(self.__class__) + attrs = list(self.__dict__.keys()) + parameters = list(self._parameters.keys()) + modules = list(self._modules.keys()) + buffers = list(self._buffers.keys()) + keys = module_attrs + attrs + parameters + modules + buffers + + # Eliminate attrs that are not legal Python variable names + keys = [key for key in keys if not key[0].isdigit()] + + return sorted(keys) + + def _replicate_for_data_parallel(self): + replica = self.__new__(type(self)) + replica.__dict__ = self.__dict__.copy() + + # replicas do not have parameters themselves, the replicas reference the original + # module. + replica._parameters = {} + replica._buffers = replica._buffers.copy() + replica._modules = replica._modules.copy() + replica._is_replica = True # type: ignore[assignment] + + return replica + + def compile(self, *args, **kwargs) -> None: + """ + Compile this Module's forward using :func:`torch.compile`. + + This Module's `__call__` method is compiled and all arguments are passed as-is + to :func:`torch.compile`. + + See :func:`torch.compile` for details on the arguments for this function. + """ + self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/normalization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..d492cdb3cf5a03c647760401fcc6f8709d87f1bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/normalization.py @@ -0,0 +1,430 @@ +# mypy: allow-untyped-defs +import numbers +from typing import Union + +import torch +from torch import Size, Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter + +from ._functions import CrossMapLRN2d as _cross_map_lrn2d +from .module import Module + + +__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"] + + +class LocalResponseNorm(Module): + r"""Applies local response normalization over an input signal. + + The input signal is composed of several input planes, where channels occupy the second dimension. + Applies normalization across channels. + + .. math:: + b_{c} = a_{c}\left(k + \frac{\alpha}{n} + \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta} + + Args: + size: amount of neighbouring channels used for normalization + alpha: multiplicative factor. Default: 0.0001 + beta: exponent. Default: 0.75 + k: additive factor. Default: 1 + + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C, *)` (same shape as input) + + Examples:: + + >>> lrn = nn.LocalResponseNorm(2) + >>> signal_2d = torch.randn(32, 5, 24, 24) + >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7) + >>> output_2d = lrn(signal_2d) + >>> output_4d = lrn(signal_4d) + + """ + + __constants__ = ["size", "alpha", "beta", "k"] + size: int + alpha: float + beta: float + k: float + + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0 + ) -> None: + super().__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self): + """ + Return the extra representation of the module. + """ + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +class CrossMapLRN2d(Module): + size: int + alpha: float + beta: float + k: float + + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1 + ) -> None: + super().__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +_shape_t = Union[int, list[int], Size] + + +class LayerNorm(Module): + r"""Applies Layer Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Layer Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated over the last `D` dimensions, where `D` + is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape` + is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over + the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``). + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, Layer Normalization applies per-element scale and + bias with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + + Attributes: + weight: the learnable weights of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 1. + bias: the learnable bias of the module of shape + :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``. + The values are initialized to 0. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> # NLP Example + >>> batch, sentence_length, embedding_dim = 20, 5, 10 + >>> embedding = torch.randn(batch, sentence_length, embedding_dim) + >>> layer_norm = nn.LayerNorm(embedding_dim) + >>> # Activate module + >>> layer_norm(embedding) + >>> + >>> # Image Example + >>> N, C, H, W = 20, 5, 10, 10 + >>> input = torch.randn(N, C, H, W) + >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions) + >>> # as shown in the image below + >>> layer_norm = nn.LayerNorm([C, H, W]) + >>> output = layer_norm(input) + + .. image:: ../_static/img/nn/layer_norm.jpg + :scale: 50 % + + """ + + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + if bias: + self.bias = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + init.ones_(self.weight) + if self.bias is not None: + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class GroupNorm(Module): + r"""Applies Group Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Group Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The input channels are separated into :attr:`num_groups` groups, each containing + ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by + :attr:`num_groups`. The mean and standard-deviation are calculated + separately over each group. :math:`\gamma` and :math:`\beta` are learnable + per-channel affine transform parameter vectors of size :attr:`num_channels` if + :attr:`affine` is ``True``. + The variance is calculated via the biased estimator, equivalent to + `torch.var(input, correction=0)`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + num_groups (int): number of groups to separate the channels into + num_channels (int): number of channels expected in input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + affine: a boolean value that when set to ``True``, this module + has learnable per-channel affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` + - Output: :math:`(N, C, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 6, 10, 10) + >>> # Separate 6 channels into 3 groups + >>> m = nn.GroupNorm(3, 6) + >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + >>> m = nn.GroupNorm(6, 6) + >>> # Put all 6 channels into a single group (equivalent with LayerNorm) + >>> m = nn.GroupNorm(1, 6) + >>> # Activating the module + >>> output = m(input) + """ + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + num_groups: int + num_channels: int + eps: float + affine: bool + + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if num_channels % num_groups != 0: + raise ValueError( + f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})" + ) + + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(torch.empty(num_channels, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format( + **self.__dict__ + ) + + +class RMSNorm(Module): + r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Root Mean Square Layer Normalization `__ + + .. math:: + y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad + \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2} + + The RMS is taken over the last ``D`` dimensions, where ``D`` + is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape` + is ``(3, 5)`` (a 2-dimensional shape), the RMS is computed over + the last 2 dimensions of the input. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: ``torch.finfo(x.dtype).eps`` + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> rms_norm = nn.RMSNorm([2, 3]) + >>> input = torch.randn(2, 2, 3) + >>> rms_norm(input) + + """ + + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: tuple[int, ...] + eps: float | None + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float | None = None, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + if self.elementwise_affine: + init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Runs the forward pass. + """ + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +# TODO: ContrastiveNorm2d +# TODO: DivisiveNorm2d +# TODO: SubtractiveNorm2d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/padding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..d5aa1e0d425548857d20b093041b190bc7f2f645 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/padding.py @@ -0,0 +1,842 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence + +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _size_2_t, _size_4_t, _size_6_t + +from .module import Module +from .utils import _ntuple, _pair, _quadruple + + +# TODO: grad_output size asserts in THNN + +__all__ = [ + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", +] + + +class _CircularPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def _check_input_dim(self, input): + raise NotImplementedError + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + return F.pad(input, self.padding, "circular") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class CircularPad1d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.CircularPad1d(2) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[2., 3., 0., 1., 2., 3., 0., 1.], + [6., 7., 4., 5., 6., 7., 4., 5.]]]) + >>> # using different paddings for different sides + >>> m = nn.CircularPad1d((3, 1)) + >>> m(input) + tensor([[[1., 2., 3., 0., 1., 2., 3., 0.], + [5., 6., 7., 4., 5., 6., 7., 4.]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + def _check_input_dim(self, input) -> None: + if input.dim() != 2 and input.dim() != 3: + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") + + +class CircularPad2d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.CircularPad2d(2) + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[4., 5., 3., 4., 5., 3., 4.], + [7., 8., 6., 7., 8., 6., 7.], + [1., 2., 0., 1., 2., 0., 1.], + [4., 5., 3., 4., 5., 3., 4.], + [7., 8., 6., 7., 8., 6., 7.], + [1., 2., 0., 1., 2., 0., 1.], + [4., 5., 3., 4., 5., 3., 4.]]]]) + >>> # using different paddings for different sides + >>> m = nn.CircularPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[5., 3., 4., 5., 3.], + [8., 6., 7., 8., 6.], + [2., 0., 1., 2., 0.], + [5., 3., 4., 5., 3.], + [8., 6., 7., 8., 6.]]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + def _check_input_dim(self, input) -> None: + if input.dim() != 3 and input.dim() != 4: + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + +class CircularPad3d(_CircularPadNd): + r"""Pads the input tensor using circular padding of the input boundary. + + Tensor values at the beginning of the dimension are used to pad the end, + and values at the end are used to pad the beginning. If negative padding is + applied then the ends of the tensor get removed. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than or equal to the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.CircularPad3d(3) + >>> input = torch.randn(16, 3, 8, 320, 480) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.CircularPad3d((3, 3, 6, 6, 1, 1)) + >>> output = m(input) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + def _check_input_dim(self, input) -> None: + if input.dim() != 4 and input.dim() != 5: + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") + + +class _ConstantPadNd(Module): + __constants__ = ["padding", "value"] + value: float + padding: Sequence[int] + + def __init__(self, value: float) -> None: + super().__init__() + self.value = value + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "constant", self.value) + + def extra_repr(self) -> str: + return f"padding={self.padding}, value={self.value}" + + +class ConstantPad1d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ConstantPad1d(2, 3.5) + >>> input = torch.randn(1, 2, 4) + >>> input + tensor([[[-1.0491, -0.7152, -0.0749, 0.8530], + [-1.3287, 1.8966, 0.1466, -0.2771]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, -1.0491, -0.7152, -0.0749, 0.8530, 3.5000, + 3.5000], + [ 3.5000, 3.5000, -1.3287, 1.8966, 0.1466, -0.2771, 3.5000, + 3.5000]]]) + >>> m = nn.ConstantPad1d(2, 3.5) + >>> input = torch.randn(1, 2, 3) + >>> input + tensor([[[ 1.6616, 1.4523, -1.1255], + [-3.6372, 0.1182, -1.8652]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000, 3.5000], + [ 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000, 3.5000]]]) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad1d((3, 1), 3.5) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000], + [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t, value: float) -> None: + super().__init__(value) + self.padding = _pair(padding) + + +class ConstantPad2d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ConstantPad2d(2, 3.5) + >>> input = torch.randn(1, 2, 2) + >>> input + tensor([[[ 1.6585, 0.4320], + [-0.8701, -0.4649]]]) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 1.6585, 0.4320, 3.5000, 3.5000], + [ 3.5000, 3.5000, -0.8701, -0.4649, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]]) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5) + >>> m(input) + tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000], + [ 3.5000, 3.5000, 3.5000, 1.6585, 0.4320], + [ 3.5000, 3.5000, 3.5000, -0.8701, -0.4649], + [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]]) + """ + + __constants__ = ["padding", "value"] + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t, value: float) -> None: + super().__init__(value) + self.padding = _quadruple(padding) + + +class ConstantPad3d(_ConstantPadNd): + r"""Pads the input tensor boundaries with a constant value. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ConstantPad3d(3, 3.5) + >>> input = torch.randn(16, 3, 10, 20, 30) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5) + >>> output = m(input) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t, value: float) -> None: + super().__init__(value) + self.padding = _ntuple(6)(padding) + + +class _ReflectionPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "reflect") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ReflectionPad1d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ReflectionPad1d(2) + >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles") + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[2., 1., 0., 1., 2., 3., 2., 1.], + [6., 5., 4., 5., 6., 7., 6., 5.]]]) + >>> # using different paddings for different sides + >>> m = nn.ReflectionPad1d((3, 1)) + >>> m(input) + tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], + [7., 6., 5., 4., 5., 6., 7., 6.]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + +class ReflectionPad2d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})` where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReflectionPad2d(2) + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[8., 7., 6., 7., 8., 7., 6.], + [5., 4., 3., 4., 5., 4., 3.], + [2., 1., 0., 1., 2., 1., 0.], + [5., 4., 3., 4., 5., 4., 3.], + [8., 7., 6., 7., 8., 7., 6.], + [5., 4., 3., 4., 5., 4., 3.], + [2., 1., 0., 1., 2., 1., 0.]]]]) + >>> # using different paddings for different sides + >>> m = nn.ReflectionPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[7., 6., 7., 8., 7.], + [4., 3., 4., 5., 4.], + [1., 0., 1., 2., 1.], + [4., 3., 4., 5., 4.], + [7., 6., 7., 8., 7.]]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + +class ReflectionPad3d(_ReflectionPadNd): + r"""Pads the input tensor using the reflection of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than the corresponding input dimension. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReflectionPad3d(1) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 1, 2, 2, 2) + >>> m(input) + tensor([[[[[7., 6., 7., 6.], + [5., 4., 5., 4.], + [7., 6., 7., 6.], + [5., 4., 5., 4.]], + [[3., 2., 3., 2.], + [1., 0., 1., 0.], + [3., 2., 3., 2.], + [1., 0., 1., 0.]], + [[7., 6., 7., 6.], + [5., 4., 5., 4.], + [7., 6., 7., 6.], + [5., 4., 5., 4.]], + [[3., 2., 3., 2.], + [1., 0., 1., 0.], + [3., 2., 3., 2.], + [1., 0., 1., 0.]]]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + +class _ReplicationPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "replicate") + + def extra_repr(self) -> str: + return f"{self.padding}" + + +class ReplicationPad1d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this") + >>> m = nn.ReplicationPad1d(2) + >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4) + >>> input + tensor([[[0., 1., 2., 3.], + [4., 5., 6., 7.]]]) + >>> m(input) + tensor([[[0., 0., 0., 1., 2., 3., 3., 3.], + [4., 4., 4., 5., 6., 7., 7., 7.]]]) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad1d((3, 1)) + >>> m(input) + tensor([[[0., 0., 0., 0., 1., 2., 3., 3.], + [4., 4., 4., 4., 5., 6., 7., 7.]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__() + self.padding = _pair(padding) + + +class ReplicationPad2d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ReplicationPad2d(2) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3) + >>> input + tensor([[[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]]) + >>> m(input) + tensor([[[[0., 0., 0., 1., 2., 2., 2.], + [0., 0., 0., 1., 2., 2., 2.], + [0., 0., 0., 1., 2., 2., 2.], + [3., 3., 3., 4., 5., 5., 5.], + [6., 6., 6., 7., 8., 8., 8.], + [6., 6., 6., 7., 8., 8., 8.], + [6., 6., 6., 7., 8., 8., 8.]]]]) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[0., 0., 1., 2., 2.], + [0., 0., 1., 2., 2.], + [0., 0., 1., 2., 2.], + [3., 3., 4., 5., 5.], + [6., 6., 7., 8., 8.]]]]) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__() + self.padding = _quadruple(padding) + + +class ReplicationPad3d(_ReplicationPadNd): + r"""Pads the input tensor using replication of the input boundary. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that the output dimensions must remain positive. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ReplicationPad3d(3) + >>> input = torch.randn(16, 3, 8, 320, 480) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1)) + >>> output = m(input) + """ + + # pyrefly: ignore [bad-override] + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__() + self.padding = _ntuple(6)(padding) + + +class ZeroPad1d(ConstantPad1d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in both boundaries. If a 2-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + + Shape: + - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. + - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad1d(2) + >>> input = torch.randn(1, 2, 4) + >>> input + tensor([[[-1.0491, -0.7152, -0.0749, 0.8530], + [-1.3287, 1.8966, 0.1466, -0.2771]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, -1.0491, -0.7152, -0.0749, 0.8530, 0.0000, + 0.0000], + [ 0.0000, 0.0000, -1.3287, 1.8966, 0.1466, -0.2771, 0.0000, + 0.0000]]]) + >>> m = nn.ZeroPad1d(2) + >>> input = torch.randn(1, 2, 3) + >>> input + tensor([[[ 1.6616, 1.4523, -1.1255], + [-3.6372, 0.1182, -1.8652]]]) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000, 0.0000], + [ 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000, 0.0000]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad1d((3, 1)) + >>> m(input) + tensor([[[ 0.0000, 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000], + [ 0.0000, 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000]]]) + """ + + padding: tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"{self.padding}" + + +class ZeroPad2d(ConstantPad2d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, + :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = nn.ZeroPad2d(2) + >>> input = torch.randn(1, 1, 3, 3) + >>> input + tensor([[[[-0.1678, -0.4418, 1.9466], + [ 0.9604, -0.4219, -0.5241], + [-0.9162, -0.5436, -0.6446]]]]) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1678, -0.4418, 1.9466, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.9604, -0.4219, -0.5241, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9162, -0.5436, -0.6446, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad2d((1, 1, 2, 0)) + >>> m(input) + tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.1678, -0.4418, 1.9466, 0.0000], + [ 0.0000, 0.9604, -0.4219, -0.5241, 0.0000], + [ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]]) + """ + + padding: tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"{self.padding}" + + +class ZeroPad3d(ConstantPad3d): + r"""Pads the input tensor boundaries with zero. + + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. + + Args: + padding (int, tuple): the size of the padding. If is `int`, uses the same + padding in all boundaries. If a 6-`tuple`, uses + (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, + :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, + :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}` + + :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}` + + :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}` + + Examples:: + + >>> m = nn.ZeroPad3d(3) + >>> input = torch.randn(16, 3, 10, 20, 30) + >>> output = m(input) + >>> # using different paddings for different sides + >>> m = nn.ZeroPad3d((3, 3, 6, 6, 0, 1)) + >>> output = m(input) + """ + + padding: tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super().__init__(padding, 0.0) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"{self.padding}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pixelshuffle.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pixelshuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..74c9e0878f0b5ecc48878c63115aafc2128b3afd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pixelshuffle.py @@ -0,0 +1,127 @@ +import torch.nn.functional as F +from torch import Tensor + +from .module import Module + + +__all__ = ["PixelShuffle", "PixelUnshuffle"] + + +class PixelShuffle(Module): + r"""Rearrange elements in a tensor according to an upscaling factor. + + Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. + + This is useful for implementing efficient sub-pixel convolution + with a stride of :math:`1/r`. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + upscale_factor (int): factor to increase spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \div \text{upscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \times \text{upscale\_factor} + + .. math:: + W_{out} = W_{in} \times \text{upscale\_factor} + + Examples:: + + >>> pixel_shuffle = nn.PixelShuffle(3) + >>> input = torch.randn(1, 9, 4, 4) + >>> output = pixel_shuffle(input) + >>> print(output.size()) + torch.Size([1, 1, 12, 12]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ["upscale_factor"] + upscale_factor: int + + def __init__(self, upscale_factor: int) -> None: + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.pixel_shuffle(input, self.upscale_factor) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"upscale_factor={self.upscale_factor}" + + +class PixelUnshuffle(Module): + r"""Reverse the PixelShuffle operation. + + Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et al. (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + Examples:: + + >>> pixel_unshuffle = nn.PixelUnshuffle(3) + >>> input = torch.randn(1, 1, 12, 12) + >>> output = pixel_unshuffle(input) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + + __constants__ = ["downscale_factor"] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super().__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + return f"downscale_factor={self.downscale_factor}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pooling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc57c25b168396fa9ceff5b32fd368befa094af --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/pooling.py @@ -0,0 +1,1550 @@ +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import ( + _ratio_2_t, + _ratio_3_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_opt_t, + _size_any_t, +) + +from .module import Module +from .utils import _pair, _single, _triple + + +__all__ = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] + + +class _MaxPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "dilation", + "return_indices", + "ceil_mode", + ] + return_indices: bool + ceil_mode: bool + + def __init__( + self, + kernel_size: _size_any_t, + stride: _size_any_t | None = None, + padding: _size_any_t = 0, + dilation: _size_any_t = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "kernel_size={kernel_size}, stride={stride}, padding={padding}" + ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class MaxPool1d(_MaxPoolNd): + r"""Applies a 1D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)` + and output :math:`(N, C, L_{out})` can be precisely described as: + + .. math:: + out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1} + input(N_i, C_j, stride \times k + m) + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the + sliding window. This `link`_ has a nice visualization of the pooling parameters. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + Args: + kernel_size: The size of the sliding window, must be > 0. + stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.MaxUnpool1d` later + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, + + where ``ceil_mode = False`` + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}}\right\rfloor + 1 + + where ``ceil_mode = True`` + + .. math:: + L_{out} = \left\lceil \frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1 + (stride - 1)}{\text{stride}}\right\rceil + 1 + + - Ensure that the last pooling starts inside the image, make :math:`L_{out} = L_{out} - 1` + when :math:`(L_{out} - 1) * \text{stride} >= L_{in} + \text{padding}`. + + Examples:: + + >>> # pool of size=3, stride=2 + >>> m = nn.MaxPool1d(3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + dilation: _size_1_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.max_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class MaxPool2d(_MaxPoolNd): + r"""Applies a 2D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`torch.nn.MaxUnpool2d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]} + \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]} + \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool2d((3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + dilation: _size_2_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.max_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class MaxPool3d(_MaxPoolNd): + r"""Applies a 3D max pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, + \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides + for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window to take a max over + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on all three sides + dilation: a parameter that controls the stride of elements in the window + return_indices: if ``True``, will return the max indices along with the outputs. + Useful for :class:`torch.nn.MaxUnpool3d` later + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times + (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times + (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times + (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.MaxPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + dilation: _size_3_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.max_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class _MaxUnpoolNd(Module): + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" + + +class MaxUnpool1d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool1d`. + + :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost. + + :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool1d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs and Example below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`. + - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0] + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?") + >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool1d(2, stride=2) + >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) + + >>> # Example showcasing the use of output_size + >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices, output_size=input.size()) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]]) + + >>> unpool(output, indices) + tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]]) + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + + def __init__( + self, + kernel_size: _size_1_t, + stride: _size_1_t | None = None, + padding: _size_1_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if (stride is not None) else kernel_size) + self.padding = _single(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + ) -> Tensor: + """Runs the forward pass.""" + return F.max_unpool1d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool2d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool2d`. + + :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost. + + :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool2d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs and Example below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool2d(2, stride=2) + >>> input = torch.tensor([[[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.]]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices) + tensor([[[[ 0., 0., 0., 0.], + [ 0., 6., 0., 8.], + [ 0., 0., 0., 0.], + [ 0., 14., 0., 16.]]]]) + >>> # Now using output_size to resolve an ambiguous size for the inverse + >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.], + [ 6., 7., 8., 9., 10.], + [11., 12., 13., 14., 15.], + [16., 17., 18., 19., 20.]]]]) + >>> output, indices = pool(input) + >>> # This call will not work without specifying output_size + >>> unpool(output, indices, output_size=input.size()) + tensor([[[[ 0., 0., 0., 0., 0.], + [ 0., 7., 0., 9., 0.], + [ 0., 0., 0., 0., 0.], + [ 0., 17., 0., 19., 0.]]]]) + + + """ + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + + def __init__( + self, + kernel_size: _size_2_t, + stride: _size_2_t | None = None, + padding: _size_2_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride if (stride is not None) else kernel_size) + self.padding = _pair(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + ) -> Tensor: + """Runs the forward pass.""" + return F.max_unpool2d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool3d(_MaxUnpoolNd): + r"""Computes a partial inverse of :class:`MaxPool3d`. + + :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost. + :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d` + including the indices of the maximal values and computes a partial inverse + in which all non-maximal values are set to zero. + + Note: + This operation may behave nondeterministically when the input indices has repeat values. + See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information. + + .. note:: :class:`MaxPool3d` can map several input sizes to the same output + sizes. Hence, the inversion process can get ambiguous. + To accommodate this, you can provide the needed output size + as an additional argument :attr:`output_size` in the forward call. + See the Inputs section below. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + It is set to :attr:`kernel_size` by default. + padding (int or tuple): Padding that was added to the input + + Inputs: + - `input`: the input Tensor to invert + - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d` + - `output_size` (optional): the targeted output size + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + + or as given by :attr:`output_size` in the call operator + + Example:: + + >>> # pool of square window of size=3, stride=2 + >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool3d(3, stride=2) + >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15)) + >>> unpooled_output = unpool(output, indices) + >>> unpooled_output.size() + torch.Size([20, 16, 51, 33, 15]) + """ + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + + def __init__( + self, + kernel_size: _size_3_t, + stride: _size_3_t | None = None, + padding: _size_3_t = 0, + ) -> None: + super().__init__() + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride if (stride is not None) else kernel_size) + self.padding = _triple(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + ) -> Tensor: + """Runs the forward pass.""" + return F.max_unpool3d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class _AvgPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + ] + + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" + + +class AvgPool1d(_AvgPoolNd): + r"""Applies a 1D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, L)`, + output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k` + can be precisely described as: + + .. math:: + + \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1} + \text{input}(N_i, C_j, \text{stride} \times l + m) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be + an ``int`` or a one-element tuple. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor \frac{L_{in} + + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in} + + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in + :math:`L_{out}` being reduced by one. + + Examples:: + + >>> # pool with window of size=3, stride=2 + >>> m = nn.AvgPool1d(3, stride=2) + >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]])) + tensor([[[2., 4., 6.]]]) + """ + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_1_t, + stride: _size_1_t = None, + padding: _size_1_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> None: + super().__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if stride is not None else kernel_size) + self.padding = _single(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.avg_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + ) + + +class AvgPool2d(_AvgPoolNd): + r"""Applies a 2D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, + output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)` + can be precisely described as: + + .. math:: + + out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} + input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: + + - a single ``int`` or a single-element tuple -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on both sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. + + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in} + + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region, + resulting in :math:`H_{out}` being reduced by one. + + The same applies for :math:`W_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool2d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool2d((3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + """ + + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_2_t, + stride: _size_2_t | None = None, + padding: _size_2_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: int | None = None, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.avg_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + +class AvgPool3d(_AvgPoolNd): + r"""Applies a 3D average pooling over an input signal composed of several input planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`, + output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)` + can be precisely described as: + + .. math:: + \begin{aligned} + \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\ + & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k, + \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)} + {kD \times kH \times kW} + \end{aligned} + + If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides + for :attr:`padding` number of points. + + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + + .. note:: + pad should be at most half of effective kernel size. + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + padding: implicit zero padding to be added on all three sides + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad: when True, will include the zero-padding in the averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - + \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - + \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - + \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor + + Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in} + + \text{padding}[0]`, we skip the last window as it would start in the padded region, + resulting in :math:`D_{out}` being reduced by one. + + The same applies for :math:`W_{out}` and :math:`H_{out}`. + + Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> m = nn.AvgPool3d(3, stride=2) + >>> # pool of non-square window + >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + """ + + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_3_t, + stride: _size_3_t | None = None, + padding: _size_3_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: int | None = None, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.avg_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + def __setstate__(self, d): + super().__setstate__(d) + self.__dict__.setdefault("padding", 0) + self.__dict__.setdefault("ceil_mode", False) + self.__dict__.setdefault("count_include_pad", True) + + +class FractionalMaxPool2d(Module): + r"""Applies a 2D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)` + output_size: the target output size of the image of the form `oH x oW`. + Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`. + Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1). + Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}` + and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}` + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + :math:`(H_{out}, W_{out})=\text{output\_size}` or + :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`. + + Examples: + >>> # pool of square window of size=3, and target output size 13x12 + >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12)) + >>> # pool of square window and target output size being half of input image size + >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + .. _Fractional MaxPooling: + https://arxiv.org/abs/1412.6071 + """ + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + + kernel_size: _size_2_t + return_indices: bool + output_size: _size_2_t + output_ratio: _ratio_2_t + + def __init__( + self, + kernel_size: _size_2_t, + output_size: _size_2_t | None = None, + output_ratio: _ratio_2_t | None = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super().__init__() + self.kernel_size = _pair(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _pair(output_size) if output_size is not None else None + self.output_ratio = _pair(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool2d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1): + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool2d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + _random_samples=self._random_samples, + ) + + +class FractionalMaxPool3d(Module): + r"""Applies a 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number `k` (for a square kernel of `k x k x k`) or a tuple `(kt x kh x kw)`, + `k` must greater than 0. + output_size: the target output size of the image of the form `oT x oH x oW`. + Can be a tuple `(oT, oH, oW)` or a single number oH for a square image `oH x oH x oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False`` + + Shape: + - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})` + + Examples: + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)) + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> output = m(input) + + .. _Fractional MaxPooling: + https://arxiv.org/abs/1412.6071 + """ + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + kernel_size: _size_3_t + return_indices: bool + output_size: _size_3_t + output_ratio: _ratio_3_t + + def __init__( + self, + kernel_size: _size_3_t, + output_size: _size_3_t | None = None, + output_ratio: _ratio_3_t | None = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super().__init__() + if (isinstance(kernel_size, int) and kernel_size <= 0) or ( + isinstance(kernel_size, (tuple, list)) + and not all(k > 0 for k in kernel_size) + ): + raise ValueError(f"kernel_size must greater than 0, but got {kernel_size}") + self.kernel_size = _triple(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _triple(output_size) if output_size is not None else None + self.output_ratio = _triple(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool3d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not ( + 0 < self.output_ratio[0] < 1 + and 0 < self.output_ratio[1] < 1 + and 0 < self.output_ratio[2] < 1 + ): + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool3d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + _random_samples=self._random_samples, + ) + + +class _LPPoolNd(Module): + __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"] + + norm_type: float + ceil_mode: bool + + def __init__( + self, + norm_type: float, + kernel_size: _size_any_t, + stride: _size_any_t | None = None, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.norm_type = norm_type + self.kernel_size = kernel_size + self.stride = stride + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, " + "ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class LPPool1d(_LPPoolNd): + r"""Applies a 1D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling) + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: a single int, the size of the window + stride: a single int, the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Note: + When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within the + left padding or the input. Sliding windows that would start in the right padded region are ignored. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor + + Examples:: + >>> # power-2 pool of window of length 3, with stride 2. + >>> m = nn.LPPool1d(2, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + """ + + kernel_size: _size_1_t + stride: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.lp_pool1d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class LPPool2d(_LPPoolNd): + r"""Applies a 2D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to average pooling) + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Note: + When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within the + left padding or the input. Sliding windows that would start in the right padded region are ignored. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + Examples:: + + >>> # power-2 pool of square window of size=3, stride=2 + >>> m = nn.LPPool2d(2, 3, stride=2) + >>> # pool of non-square window of power 1.2 + >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1)) + >>> input = torch.randn(20, 16, 50, 32) + >>> output = m(input) + + """ + + kernel_size: _size_2_t + stride: _size_2_t + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.lp_pool2d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class LPPool3d(_LPPoolNd): + r"""Applies a 3D power-average pooling over an input signal composed of several input planes. + + On each window, the function computed is: + + .. math:: + f(X) = \sqrt[p]{\sum_{x \in X} x^{p}} + + - At p = :math:`\infty`, one gets Max Pooling + - At p = 1, one gets Sum Pooling (which is proportional to average pooling) + + The parameters :attr:`kernel_size`, :attr:`stride` can either be: + + - a single ``int`` -- in which case the same value is used for the height, width and depth dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + .. note:: If the sum to the power of `p` is zero, the gradient of this function is + not defined. This implementation will set the gradient to zero in this case. + + Args: + kernel_size: the size of the window + stride: the stride of the window. Default value is :attr:`kernel_size` + ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape + + Note: + When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within the + left padding or the input. Sliding windows that would start in the right padded region are ignored. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or + :math:`(C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor + + Examples:: + + >>> # power-2 pool of square window of size=3, stride=2 + >>> m = nn.LPPool3d(2, 3, stride=2) + >>> # pool of non-square window of power 1.2 + >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2)) + >>> input = torch.randn(20, 16, 50, 44, 31) + >>> output = m(input) + + """ + + kernel_size: _size_3_t + stride: _size_3_t + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.lp_pool3d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class _AdaptiveMaxPoolNd(Module): + __constants__ = ["output_size", "return_indices"] + return_indices: bool + + def __init__( + self, output_size: _size_any_opt_t, return_indices: bool = False + ) -> None: + super().__init__() + self.output_size = output_size + self.return_indices = return_indices + + def extra_repr(self) -> str: + return f"output_size={self.output_size}" + + +# FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and +# output shapes are, and how the operation computes output. + + +class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): + r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. + + The output size is :math:`L_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size :math:`L_{out}`. + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool1d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + :math:`L_{out}=\text{output\_size}`. + + Examples: + >>> # target output size of 5 + >>> m = nn.AdaptiveMaxPool1d(5) + >>> input = torch.randn(1, 64, 8) + >>> output = m(input) + + """ + + output_size: _size_1_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): + r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. + + The output is of size :math:`H_{out} \times W_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`. + Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a + square image :math:`H_{out} \times H_{out}`. :math:`H_{out}` and :math:`W_{out}` + can be either a ``int``, or ``None`` which means the size will be the same as that + of the input. + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool2d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where + :math:`(H_{out}, W_{out})=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7 + >>> m = nn.AdaptiveMaxPool2d((5, 7)) + >>> input = torch.randn(1, 64, 8, 9) + >>> output = m(input) + >>> # target output size of 7x7 (square) + >>> m = nn.AdaptiveMaxPool2d(7) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + >>> # target output size of 10x7 + >>> m = nn.AdaptiveMaxPool2d((None, 7)) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + + """ + + output_size: _size_2_opt_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): + r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. + + The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`. + Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single + :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`. + :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` can be either a + ``int``, or ``None`` which means the size will be the same as that of the input. + + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to nn.MaxUnpool3d. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, + where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7x9 + >>> m = nn.AdaptiveMaxPool3d((5, 7, 9)) + >>> input = torch.randn(1, 64, 8, 9, 10) + >>> output = m(input) + >>> # target output size of 7x7x7 (cube) + >>> m = nn.AdaptiveMaxPool3d(7) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + >>> # target output size of 7x9x8 + >>> m = nn.AdaptiveMaxPool3d((7, None, None)) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + + """ + + output_size: _size_3_opt_t + + def forward(self, input: Tensor): + """Runs the forward pass.""" + return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) + + +class _AdaptiveAvgPoolNd(Module): + __constants__ = ["output_size"] + + def __init__(self, output_size: _size_any_opt_t) -> None: + super().__init__() + self.output_size = output_size + + def extra_repr(self) -> str: + return f"output_size={self.output_size}" + + +class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): + r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes. + + The output size is :math:`L_{out}`, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size :math:`L_{out}`. + + Shape: + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where + :math:`L_{out}=\text{output\_size}`. + + Examples: + >>> # target output size of 5 + >>> m = nn.AdaptiveAvgPool1d(5) + >>> input = torch.randn(1, 64, 8) + >>> output = m(input) + + """ + + output_size: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.adaptive_avg_pool1d(input, self.output_size) + + +class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): + r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes. + + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H. + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. + - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where + :math:`S=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7 + >>> m = nn.AdaptiveAvgPool2d((5, 7)) + >>> input = torch.randn(1, 64, 8, 9) + >>> output = m(input) + >>> # target output size of 7x7 (square) + >>> m = nn.AdaptiveAvgPool2d(7) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + >>> # target output size of 10x7 + >>> m = nn.AdaptiveAvgPool2d((None, 7)) + >>> input = torch.randn(1, 64, 10, 9) + >>> output = m(input) + + """ + + output_size: _size_2_opt_t + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.adaptive_avg_pool2d(input, self.output_size) + + +class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): + r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes. + + The output is of size D x H x W, for any input size. + The number of output features is equal to the number of input planes. + + Args: + output_size: the target output size of the form D x H x W. + Can be a tuple (D, H, W) or a single number D for a cube D x D x D. + D, H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + + Shape: + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`, + where :math:`S=\text{output\_size}`. + + Examples: + >>> # target output size of 5x7x9 + >>> m = nn.AdaptiveAvgPool3d((5, 7, 9)) + >>> input = torch.randn(1, 64, 8, 9, 10) + >>> output = m(input) + >>> # target output size of 7x7x7 (cube) + >>> m = nn.AdaptiveAvgPool3d(7) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + >>> # target output size of 7x9x8 + >>> m = nn.AdaptiveAvgPool3d((7, None, None)) + >>> input = torch.randn(1, 64, 10, 9, 8) + >>> output = m(input) + + """ + + output_size: _size_3_opt_t + + def forward(self, input: Tensor) -> Tensor: + """Runs the forward pass.""" + return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..68e8292870fc8a8d19ce3307294377b162c8b6fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/rnn.py @@ -0,0 +1,1850 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import numbers +import warnings +import weakref +from typing import overload +from typing_extensions import deprecated + +import torch +from torch import _VF, Tensor +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.rnn import PackedSequence + +from .module import Module + + +__all__ = [ + "RNNBase", + "RNN", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] + +_rnn_impls = { + "RNN_TANH": _VF.rnn_tanh, + "RNN_RELU": _VF.rnn_relu, +} + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return _apply_permutation(tensor, permutation, dim) + + +class RNNBase(Module): + r"""Base class for RNN modules (RNN, LSTM, GRU). + + Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization + and utility methods for parameter storage management. + + .. note:: + The forward method is not implemented by the RNNBase class. + + .. note:: + LSTM and GRU classes override some methods implemented by RNNBase. + """ + + __constants__ = [ + "mode", + "input_size", + "hidden_size", + "num_layers", + "bias", + "batch_first", + "dropout", + "bidirectional", + "proj_size", + ] + __jit_unused_properties__ = ["all_weights"] + + mode: str + input_size: int + hidden_size: int + num_layers: int + bias: bool + batch_first: bool + dropout: float + bidirectional: bool + proj_size: int + + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.proj_size = proj_size + self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = [] + num_directions = 2 if bidirectional else 1 + + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}", + stacklevel=2, + ) + + if not isinstance(hidden_size, int): + raise TypeError( + f"hidden_size should be of type int, got: {type(hidden_size).__name__}" + ) + if hidden_size <= 0: + raise ValueError("hidden_size must be greater than zero") + if num_layers <= 0: + raise ValueError("num_layers must be greater than zero") + if proj_size < 0: + raise ValueError( + "proj_size should be a positive integer or zero to disable projections" + ) + if proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + elif mode == "RNN_TANH": + gate_size = hidden_size + elif mode == "RNN_RELU": + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + real_hidden_size = proj_size if proj_size > 0 else hidden_size + layer_input_size = ( + input_size if layer == 0 else real_hidden_size * num_directions + ) + + w_ih = Parameter( + torch.empty((gate_size, layer_input_size), **factory_kwargs) + ) + w_hh = Parameter( + torch.empty((gate_size, real_hidden_size), **factory_kwargs) + ) + b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) + layer_params: tuple[Tensor, ...] = () + if self.proj_size == 0: + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + else: + w_hr = Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) + else: + layer_params = (w_ih, w_hh, w_hr) + + suffix = "_reverse" if direction == 1 else "" + param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] + if bias: + param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] + if self.proj_size > 0: + param_names += ["weight_hr_l{}{}"] + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params, strict=True): + setattr(self, name, param) + self._flat_weights_names.extend(param_names) + self._all_weights.append(param_names) + + self._init_flat_weights() + + self.reset_parameters() + + def _init_flat_weights(self) -> None: + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + self.flatten_parameters() + + def __setattr__(self, attr, value) -> None: + if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: + # keep self._flat_weights up to date if you do self.weight = ... + idx = self._flat_weights_names.index(attr) + self._flat_weights[idx] = value + super().__setattr__(attr, value) + + def flatten_parameters(self) -> None: + """Reset parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. + """ + # Short-circuits if _flat_weights is only partially instantiated + if len(self._flat_weights) != len(self._flat_weights_names): + return + + for w in self._flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN + # or the tensors in _flat_weights are of different dtypes + + first_fw = self._flat_weights[0] # type: ignore[union-attr] + dtype = first_fw.dtype # type: ignore[union-attr] + for fw in self._flat_weights: + if ( + not isinstance(fw, Tensor) + or fw.dtype != dtype + or not fw.is_cuda + or not torch.backends.cudnn.is_acceptable(fw) + ): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = { + p.data_ptr() # type: ignore[union-attr] + for p in self._flat_weights + } + if len(unique_data_ptrs) != len(self._flat_weights): + return + + with torch.cuda.device_of(first_fw): + import torch.backends.cudnn.rnn as rnn + + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + self._flat_weights, # type: ignore[arg-type] + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) + + def _apply(self, fn, recurse=True): + self._flat_weight_refs = [] + ret = super()._apply(fn, recurse) + + # Resets _flat_weights + # Note: be v. careful before removing this, as 3rd party device types + # likely rely on this behavior to properly .to() modules like LSTM. + self._init_flat_weights() + + return ret + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None: + if not torch.jit.is_scripting(): + if ( + input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr] + and not torch._C._is_any_autocast_enabled() + ): + raise ValueError( + f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr] + ) + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Tensor | None + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + if self.proj_size > 0: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.proj_size, + ) + else: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def _weights_have_changed(self): + # Returns True if the weight tensors have changed since the last forward pass. + # This is the case when used with torch.func.functional_call(), for example. + weights_changed = False + for ref, name in zip( + self._flat_weight_refs, self._flat_weights_names, strict=True + ): + weight = getattr(self, name) if hasattr(self, name) else None + if weight is not None and ref is not None and ref() is not weight: + weights_changed = True + break + return weights_changed + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size(hidden, expected_hidden_size) + + def permute_hidden(self, hx: Tensor, permutation: Tensor | None): + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if self.proj_size != 0: + s += ", proj_size={proj_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def _update_flat_weights(self) -> None: + if not torch.jit.is_scripting(): + if self._weights_have_changed(): + self._init_flat_weights() + + def __getstate__(self): + # If weights have been changed, update the _flat_weights in __getstate__ here. + self._update_flat_weights() + # Don't serialize the weight references. + state = self.__dict__.copy() + del state["_flat_weight_refs"] + return state + + def __setstate__(self, d): + super().__setstate__(d) + if "all_weights" in d: + self._all_weights = d["all_weights"] + # In PyTorch 1.8 we added a proj_size member variable to LSTM. + # LSTMs that were serialized via torch.save(module) before PyTorch 1.8 + # don't have it, so to preserve compatibility we set proj_size here. + if "proj_size" not in d: + self.proj_size = 0 + + if not isinstance(self._all_weights[0][0], str): + num_layers = self.num_layers + num_directions = 2 if self.bidirectional else 1 + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + weights = [ + "weight_ih_l{}{}", + "weight_hh_l{}{}", + "bias_ih_l{}{}", + "bias_hh_l{}{}", + "weight_hr_l{}{}", + ] + weights = [x.format(layer, suffix) for x in weights] + if self.bias: + if self.proj_size > 0: + self._all_weights += [weights] + self._flat_weights_names.extend(weights) + else: + self._all_weights += [weights[:4]] + self._flat_weights_names.extend(weights[:4]) + else: + if self.proj_size > 0: + self._all_weights += [weights[:2]] + [weights[-1:]] + self._flat_weights_names.extend( + weights[:2] + [weights[-1:]] + ) + else: + self._all_weights += [weights[:2]] + self._flat_weights_names.extend(weights[:2]) + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + + @property + def all_weights(self) -> list[list[Parameter]]: + return [ + [getattr(self, weight) for weight in weights] + for weights in self._all_weights + ] + + def _replicate_for_data_parallel(self): + replica = super()._replicate_for_data_parallel() + # Need to copy these caches, otherwise the replica will share the same + # flat weights list. + replica._flat_weights = replica._flat_weights[:] + replica._flat_weights_names = replica._flat_weights_names[:] + return replica + + +class RNN(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` + non-linearity to an input sequence. For each element in the input sequence, + each layer computes the following function: + + .. math:: + h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time `t-1` or the initial hidden state at time `0`. + If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. + + .. code-block:: python + + # Efficient implementation equivalent to the following with bidirectional=False + rnn = nn.RNN(input_size, hidden_size, num_layers) + params = dict(rnn.named_parameters()) + def forward(x, hx=None, batch_first=False): + if batch_first: + x = x.transpose(0, 1) + seq_len, batch_size, _ = x.size() + if hx is None: + hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size) + h_t_minus_1 = hx.clone() + h_t = hx.clone() + output = [] + for t in range(seq_len): + for layer in range(rnn.num_layers): + input_t = x[t] if layer == 0 else h_t[layer - 1] + h_t[layer] = torch.tanh( + input_t @ params[f"weight_ih_l{layer}"].T + + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T + + params[f"bias_hh_l{layer}"] + + params[f"bias_ih_l{layer}"] + ) + output.append(h_t[-1].clone()) + h_t_minus_1 = h_t.clone() + output = torch.stack(output) + if batch_first: + output = output.transpose(0, 1) + return output, h_t + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two RNNs together to form a `stacked RNN`, + with the second RNN taking in outputs of the first RNN and + computing the final results. Default: 1 + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + RNN layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` + + Inputs: input, hx + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for the input sequence batch. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. + + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is + `(hidden_size, num_directions * hidden_size)` + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + of shape `(hidden_size, hidden_size)` + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + of shape `(hidden_size)` + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.RNN(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + nonlinearity: str = "tanh", + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs) -> None: ... + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + if len(args) > 3: + self.nonlinearity = args[3] + args = args[:3] + args[4:] + else: + self.nonlinearity = kwargs.pop("nonlinearity", "tanh") + if self.nonlinearity == "tanh": + mode = "RNN_TANH" + elif self.nonlinearity == "relu": + mode = "RNN_RELU" + else: + raise ValueError( + f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'." + ) + super().__init__(mode, *args, **kwargs) + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: Tensor, + hx: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: PackedSequence, + hx: Tensor | None = None, + ) -> tuple[PackedSequence, Tensor]: + pass + + def forward(self, input, hx=None): # noqa: F811 + """ + Runs the forward pass. + """ + self._update_flat_weights() + + num_directions = 2 if self.bidirectional else 1 + orig_input = input + + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + assert hx is not None + self.check_forward_args(input, hx, batch_sizes) + assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" + if batch_sizes is None: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.rnn_relu( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + else: + result = _VF.rnn_relu( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + + output = result[0] + hidden = result[1] + + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, + batch_sizes, + sorted_indices, + unsorted_indices, + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +# XXX: LSTM and GRU implementation is different from RNNBase, this is because: +# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in +# its current state could not support the python Union Type or Any Type +# 2. TorchScript static typing does not allow a Function or Callable type in +# Dict values, so we have to separately call _VF instead of using _rnn_impls +# 3. This is temporary only and in the transition state that we want to make it +# on time for the release +# +# More discussion details in https://github.com/pytorch/pytorch/pull/23266 +# +# TODO: remove the overriding implementations for LSTM and GRU when TorchScript +# support expressing these two modules generally. + + +class LSTM(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None) + + Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ + f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ + o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ + c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ + h_t = o_t \odot \tanh(c_t) \\ + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell + state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` + is the hidden state of the layer at time `t-1` or the initial hidden + state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, + :math:`o_t` are the input, forget, cell, and output gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes + the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from + ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). + Second, the output hidden state of each layer will be multiplied by a learnable projection + matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output + of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact + dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two LSTMs together to form a `stacked LSTM`, + with the second LSTM taking in outputs of the first LSTM and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + LSTM layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` + proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 + + Inputs: input, (h_0, c_0) + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} + + Outputs: output, (h_n, c_n) + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. When ``bidirectional=True``, `output` will contain + a concatenation of the forward and reverse hidden states at each time step in the sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the sequence. When ``bidirectional=True``, + `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively. + * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the sequence. When ``bidirectional=True``, + `c_n` will contain a concatenation of the final forward and reverse cell states, respectively. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If + ``proj_size > 0`` was specified, the shape will be + `(4*hidden_size, num_directions * proj_size)` for `k > 0` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` + was specified, the shape will be `(4*hidden_size, proj_size)`. + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` + weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer + of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was + specified. + weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction. + Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified. + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the + former contains the final forward and reverse hidden states, while the latter contains the + final forward hidden state and the initial reverse hidden state. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + ``proj_size`` should be smaller than ``hidden_size``. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs) -> None: ... + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Tensor | None + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], # type: ignore[override] + batch_sizes: Tensor | None, + ) -> None: + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Tensor | None, + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: Tensor, + hx: tuple[Tensor, Tensor] | None = None, + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 + pass + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: PackedSequence, + hx: tuple[Tensor, Tensor] | None = None, + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + if input.dim() not in (2, 3): + raise ValueError( + f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + self.check_forward_args(input, hx, batch_sizes) + else: + if is_batched: + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + self.check_forward_args(input, hx, batch_sizes) + hx = self.permute_hidden(hx, sorted_indices) + + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, + batch_sizes, + sorted_indices, + unsorted_indices, + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + +class GRU(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` + containing the initial hidden state for the input sequence. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for the input sequence. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... + + @overload + def __init__(self, *args, **kwargs) -> None: ... + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: Tensor, + hx: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: # noqa: F811 + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, + input: PackedSequence, + hx: Tensor | None = None, + ) -> tuple[PackedSequence, Tensor]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self._flat_weights, # type: ignore[arg-type] + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, + batch_sizes, + sorted_indices, + unsorted_indices, + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +class RNNCellBase(Module): + __constants__ = ["input_size", "hidden_size", "bias"] + + input_size: int + hidden_size: int + bias: bool + weight_ih: Tensor + weight_hh: Tensor + # WARNING: bias_ih and bias_hh purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_ih = Parameter( + torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs) + ) + self.weight_hh = Parameter( + torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs) + ) + if bias: + self.bias_ih = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + self.bias_hh = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.reset_parameters() + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + + .. math:: + + h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) + + If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + + Inputs: input, hidden + - **input**: tensor containing input features + - **hidden**: tensor containing the initial hidden state + Defaults to zero if not provided. + + Outputs: h' + - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + Examples:: + + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + nonlinearity: str + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + .. math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ + c' = f \odot c + i \odot g \\ + h' = o \odot \tanh(c') \\ + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, (h_0, c_0) + - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features + - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state + - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state + + If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. + + Outputs: (h_1, c_1) + - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state + - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(4*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(4*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) + >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) + >>> hx = torch.randn(3, 20) # (batch, hidden_size) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(input.size()[0]): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + >>> output = torch.stack(output, dim=0) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def forward( + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None + ) -> tuple[Tensor, Tensor]: + if input.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None: + for idx, value in enumerate(hx): + if value.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell. + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ + n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ + h' = (1 - z) \odot n + z \odot h + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, hidden + - **input** : tensor containing input features + - **hidden** : tensor containing the initial hidden + state for each element in the batch. + Defaults to zero if not provided. + + Outputs: h' + - **h'** : tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(3*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(3*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/sparse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec531abce695374b919fbf92d4863ce73da515f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/sparse.py @@ -0,0 +1,549 @@ +# mypy: allow-untyped-defs + +import torch +from torch import Tensor +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter + +from .module import Module + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". For a newly constructed Embedding, + the embedding vector at :attr:`padding_idx` will default to all zeros, + but can be updated to another value to be used as the padding vector. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the + :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be + modified in-place, performing a differentiable operation on ``Embedding.weight`` before + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + :attr:`max_norm` is not ``None``. For example:: + + n, d, m = 3, 5, 7 + embedding = nn.Embedding(n, d, max_norm=1.0) + W = torch.randn((m, d), requires_grad=True) + idx = torch.tensor([1, 2]) + a = ( + embedding.weight.clone() @ W.t() + ) # weight must be cloned for this to be differentiable + b = embedding(idx) @ W.t() # modifies weight in-place + out = a.unsqueeze(0) + b.unsqueeze(1) + loss = out.sigmoid().prod() + loss.backward() + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0, 2, 0, 5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + >>> # example of changing `pad` vector + >>> padding_idx = 0 + >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx) + >>> embedding.weight + Parameter containing: + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + >>> with torch.no_grad(): + ... embedding.weight[padding_idx] = torch.ones(3) + >>> embedding.weight + Parameter containing: + tensor([[ 1.0000, 1.0000, 1.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + """ + + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: int | None + max_norm: float | None + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + freeze: bool + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None = None, + max_norm: float | None = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Tensor | None = None, + _freeze: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs), + requires_grad=not _freeze, + ) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = Parameter(_weight, requires_grad=not _freeze) + + self.sparse = sparse + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + @classmethod + def from_pretrained( + cls, + embeddings, + freeze=True, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + r"""Create Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + _freeze=freeze, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + return embedding + + +class EmbeddingBag(Module): + r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. + + For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`, + and with 2D inputs, this class + + * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``, + * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``, + * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``. + + However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these + operations. + + EmbeddingBag also supports per-sample weights as an argument to the forward + pass. This scales the output of the Embedding before performing a weighted + reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the + only supported ``mode`` is ``"sum"``, which computes a weighted sum according to + :attr:`per_sample_weights`. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights` + into consideration. ``"mean"`` computes the average of the values + in the bag, ``"max"`` computes the max value over each bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See + Notes for more details regarding sparse gradients. Note: this option is not + supported when ``mode="max"``. + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position + of the last bag (sequence). This matches the CSR format. Ignored when + input is 2D. Default ``False``. + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". For a newly constructed + EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all + zeros, but can be updated to another value to be used as the padding vector. + Note that the embedding vector at :attr:`padding_idx` is excluded from the + reduction. + + Attributes: + weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + initialized from :math:`\mathcal{N}(0, 1)`. + + Examples:: + + >>> # an EmbeddingBag module containing 10 tensors of size 3 + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding_sum(input, offsets) + tensor([[-0.8861, -5.4350, -0.0523], + [ 1.1306, -2.5798, -1.0044]]) + + >>> # Example with padding_idx + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> embedding_sum(input, offsets) + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + + >>> # An EmbeddingBag can be loaded from an Embedding like so + >>> embedding = nn.Embedding(10, 3, padding_idx=2) + >>> embedding_sum = nn.EmbeddingBag.from_pretrained( + embedding.weight, + padding_idx=embedding.padding_idx, + mode='sum') + """ + + __constants__ = [ + "num_embeddings", + "embedding_dim", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "mode", + "sparse", + "include_last_offset", + "padding_idx", + ] + + num_embeddings: int + embedding_dim: int + max_norm: float | None + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + mode: str + sparse: bool + include_last_offset: bool + padding_idx: int | None + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: float | None = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Tensor | None = None, + include_last_offset: bool = False, + padding_idx: int | None = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + if _weight is None: + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs) + ) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = Parameter(_weight) + self.mode = mode + self.sparse = sparse + self.include_last_offset = include_last_offset + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward( + self, + input: Tensor, + offsets: Tensor | None = None, + per_sample_weights: Tensor | None = None, + ) -> Tensor: + """Forward pass of EmbeddingBag. + + Args: + input (Tensor): Tensor containing bags of indices into the embedding matrix. + offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``. + + Returns: + Tensor output shape of `(B, embedding_dim)`. + + .. note:: + + A few notes about ``input`` and ``offsets``: + + - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the + starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`, + :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have + returned vectors filled by zeros. + """ + return F.embedding_bag( + input, + self.weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ", mode={mode}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + return s.format(**{k: repr(v) for k, v in self.__dict__.items()}) + + @classmethod + def from_pretrained( + cls, + embeddings: Tensor, + freeze: bool = True, + max_norm: float | None = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + include_last_offset: bool = False, + padding_idx: int | None = None, + ) -> "EmbeddingBag": + r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag. + First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True`` + max_norm (float, optional): See module initialization documentation. Default: ``None`` + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + mode (str, optional): See module initialization documentation. Default: ``"mean"`` + sparse (bool, optional): See module initialization documentation. Default: ``False``. + include_last_offset (bool, optional): See module initialization documentation. Default: ``False``. + padding_idx (int, optional): See module initialization documentation. Default: ``None``. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([[1, 0]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embeddingbag(input) + tensor([[ 2.5000, 3.7000, 4.6500]]) + """ + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) + rows, cols = embeddings.shape + embeddingbag = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + embeddingbag.weight.requires_grad = not freeze + return embeddingbag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/transformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6841e85ed6d2e423aa30e95b5b1d3e62f30ec9fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/transformer.py @@ -0,0 +1,1256 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import xavier_uniform_ + +from .activation import MultiheadAttention +from .container import ModuleList +from .dropout import Dropout +from .linear import Linear +from .module import Module +from .normalization import LayerNorm + + +__all__ = [ + "Transformer", + "TransformerEncoder", + "TransformerDecoder", + "TransformerEncoderLayer", + "TransformerDecoderLayer", +] + + +def _generate_square_subsequent_mask( + sz: int, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> Tensor: + r"""Generate a square causal mask for the sequence. + + The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). + """ + return torch.triu( + torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), + diagonal=1, + ) + + +def _get_seq_len(src: Tensor, batch_first: bool) -> int | None: + if src.is_nested: + return None + else: + src_size = src.size() + if len(src_size) == 2: + # unbatched: S, E + return src_size[0] + else: + # batched: B, S, E if batch_first else S, B, E + seq_len_pos = 1 if batch_first else 0 + return src_size[seq_len_pos] + + +class Transformer(Module): + r"""A basic transformer layer. + + + This Transformer layer implements the original Transformer architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build an efficient transformer layer from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of encoder/decoder intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before + other attention and feedforward operations, otherwise after. Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = torch.rand((10, 32, 512)) + >>> tgt = torch.rand((20, 32, 512)) + >>> out = transformer_model(src, tgt) + + Note: A full example to apply nn.Transformer module for the word language model is available in + https://github.com/pytorch/examples/tree/master/word_language_model + """ + + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str | Callable[[Tensor], Tensor] = F.relu, + custom_encoder: Any | None = None, + custom_decoder: Any | None = None, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + encoder_norm = LayerNorm( + d_model, + eps=layer_norm_eps, + bias=bias, + # pyrefly: ignore [bad-argument-type] + **factory_kwargs, + ) + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + decoder_norm = LayerNorm( + d_model, + eps=layer_norm_eps, + bias=bias, + # pyrefly: ignore [bad-argument-type] + **factory_kwargs, + ) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + self.batch_first = batch_first + + def forward( + self, + src: Tensor, + tgt: Tensor, + src_mask: Tensor | None = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + src_is_causal: bool | None = None, + tgt_is_causal: bool | None = None, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Take in and process masked source/target sequences. + + .. note:: + + If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are + not allowed to participate in the attention, + which is the opposite of the definition for :attr:`attn_mask` + in :func:`torch.nn.functional.scaled_dot_product_attention`. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the Tensor mask for src keys per batch (optional). + tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the Tensor mask for memory keys per batch (optional). + src_is_causal: If specified, applies a causal mask as ``src_mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``src_is_causal`` provides a hint that ``src_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory_mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or + `(N, S, E)` if `batch_first=True`. + - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or + `(N, T, E)` if `batch_first=True`. + - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`. + - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`. + - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. + + Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by + the attention. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or + `(N, T, E)` if `batch_first=True`. + + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decoder. + + where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the + batch size, :math:`E` is the feature number + + Examples: + >>> # xdoctest: +SKIP + >>> output = transformer_model( + ... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask + ... ) + """ + is_batched = src.dim() == 3 + if not self.batch_first and src.size(1) != tgt.size(1) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + elif self.batch_first and src.size(0) != tgt.size(0) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model: + raise RuntimeError( + "the feature number of src and tgt must be equal to d_model" + ) + + memory = self.encoder( + src, + mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + is_causal=src_is_causal, + ) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) + return output + + @staticmethod + def generate_square_subsequent_mask( + sz: int, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Tensor: + r"""Generate a square causal mask for the sequence. + + The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). + """ + return _generate_square_subsequent_mask(sz, dtype=dtype, device=device) + + def _reset_parameters(self) -> None: + r"""Initiate parameters in the transformer model.""" + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers. + + This TransformerEncoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + .. warning:: + All layers in the TransformerEncoder are initialized with the same parameters. + It is recommended to manually initialize the layers after creating the TransformerEncoder instance. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + __constants__ = ["norm"] + + def __init__( + self, + encoder_layer: "TransformerEncoderLayer", + num_layers: int, + norm: Module | None = None, + enable_nested_tensor: bool = True, + mask_check: bool = True, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + # this attribute saves the value providedat object construction + self.enable_nested_tensor = enable_nested_tensor + # this attribute controls whether nested tensors are used + self.use_nested_tensor = enable_nested_tensor + self.mask_check = mask_check + + enc_layer = "encoder_layer" + why_not_sparsity_fast_path = "" + if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer): + why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer" + elif encoder_layer.norm_first: + why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True" + elif not encoder_layer.self_attn.batch_first: + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn.batch_first was not True" + + "(use batch_first for better inference performance)" + ) + elif not encoder_layer.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn._qkv_same_embed_dim was not True" + ) + elif encoder_layer.self_attn.in_proj_bias is None: + why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False" + elif not encoder_layer.activation_relu_or_gelu: + why_not_sparsity_fast_path = ( + f"{enc_layer}.activation_relu_or_gelu was not True" + ) + elif encoder_layer.norm1.eps != encoder_layer.norm2.eps: + why_not_sparsity_fast_path = ( + f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps" + ) + elif encoder_layer.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd" + + if enable_nested_tensor and why_not_sparsity_fast_path: + warnings.warn( + f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}", + stacklevel=2, + ) + self.use_nested_tensor = False + + def forward( + self, + src: Tensor, + mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + is_causal: bool | None = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + is_causal: If specified, applies a causal mask as ``mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``is_causal`` provides a hint that ``mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(mask), + other_name="mask", + target_type=src.dtype, + ) + + mask = F._canonical_mask( + mask=mask, + mask_name="mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + output = src + convert_to_nested = False + first_layer = self.layers[0] + src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = "" + str_first_layer = "self.layers[0]" + batch_first = first_layer.self_attn.batch_first + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + do_mask_check = getattr(self, "mask_check", True) + + if not is_fastpath_enabled: + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) + elif not hasattr(self, "use_nested_tensor"): + why_not_sparsity_fast_path = "use_nested_tensor attribute not present" + elif not self.use_nested_tensor: + why_not_sparsity_fast_path = ( + "self.use_nested_tensor (set in init) was not True" + ) + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif src.dim() != 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + # This check avoids a call to torch._nested_tensor_from_mask_left_aligned() that + # breaks in torch.compile. + elif do_mask_check and torch.compiler.is_compiling(): + why_not_sparsity_fast_path = ( + "mask_check enabled with torch.compile or torch.export" + ) + elif do_mask_check and not torch._nested_tensor_from_mask_left_aligned( + src, src_key_padding_mask.logical_not() + ): + why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = ( + "src_key_padding_mask and mask were both supplied" + ) + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + _supported_device_type = [ + "cpu", + "cuda", + "xpu", + torch.utils.backend_registration._privateuse1_backend_name, + ] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif src.device.type not in _supported_device_type: + why_not_sparsity_fast_path = ( + f"src device is neither one of {_supported_device_type}" + ) + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = torch._nested_tensor_from_mask( + output, src_key_padding_mask.logical_not(), mask_check=False + ) + src_key_padding_mask_for_layers = None + + seq_len = _get_seq_len(src, batch_first) + is_causal = _detect_is_causal_mask(mask, is_causal, seq_len) + + for mod in self.layers: + output = mod( + output, + src_mask=mask, + is_causal=is_causal, + src_key_padding_mask=src_key_padding_mask_for_layers, + ) + + if convert_to_nested: + output = output.to_padded_tensor(0.0, src.size()) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers. + + This TransformerDecoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + .. warning:: + All layers in the TransformerDecoder are initialized with the same parameters. + It is recommended to manually initialize the layers after creating the TransformerDecoder instance. + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + + __constants__ = ["norm"] + + def __init__( + self, + decoder_layer: "TransformerDecoderLayer", + num_layers: int, + norm: Module | None = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + tgt_is_causal: bool | None = None, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``None``; try to detect a causal mask. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + output = tgt + + seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) + tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) + + for mod in self.layers: + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + + This TransformerEncoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + TransformerEncoderLayer can handle either traditional torch.tensor inputs, + or Nested Tensor inputs. Derived classes are expected to similarly accept + both input formats. (Not all combinations of inputs are currently + supported by TransformerEncoderLayer while Nested Tensor is in prototype + state.) + + If you are implementing a custom layer, you may derive it either from + the Module or TransformerEncoderLayer class. If your custom layer + supports both torch.Tensors and Nested Tensors inputs, make its + implementation a derived class of TransformerEncoderLayer. If your custom + Layer supports only torch.Tensor inputs, derive its implementation from + Module. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, layer norm is done prior to attention and feedforward + operations, respectively. Otherwise it's done after. Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + + Alternatively, when ``batch_first`` is ``True``: + >>> encoder_layer = nn.TransformerEncoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + + Fast path: + forward() will use a special optimized implementation described in + `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following + conditions are met: + + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor + argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``) + - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu`` + - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed + - if src is a `NestedTensor `_, neither ``src_mask`` + nor ``src_key_padding_mask`` is passed + - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case + unless the caller has manually modified one without modifying the other) + + If the optimized implementation is in use, a + `NestedTensor `_ can be + passed for ``src`` to represent padding more efficiently than using a padding + mask. In this case, a `NestedTensor `_ will be + returned, and an additional speedup proportional to the fraction of the input that + is padding can be expected. + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ["norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str | Callable[[Tensor], Tensor] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + bias=bias, + batch_first=batch_first, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + # pyrefly: ignore [bad-argument-type] + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + is_causal: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + is_causal: If specified, applies a causal mask as ``src mask``. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``src_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype, + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + why_not_sparsity_fast_path = "" + if not is_fastpath_enabled: + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) + elif src.dim() != 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first: + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif self.self_attn.in_proj_bias is None: + why_not_sparsity_fast_path = "self_attn was passed bias=False" + elif not self.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif self.norm1.eps != self.norm2.eps: + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None + ): + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" + elif self.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif torch.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + elif any( + len(getattr(m, "_forward_hooks", {})) + + len(getattr(m, "_forward_pre_hooks", {})) + for m in self.modules() + ): + why_not_sparsity_fast_path = "forward pre-/hooks are attached to the module" + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + ) + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + _supported_device_type = [ + "cpu", + "cuda", + "xpu", + torch.utils.backend_registration._privateuse1_backend_name, + ] + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not all( + (x.device.type in _supported_device_type) for x in tensor_args + ): + why_not_sparsity_fast_path = ( + "some Tensor argument's device is neither one of " + f"{_supported_device_type}" + ) + elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if not why_not_sparsity_fast_path: + merged_mask, mask_type = self.self_attn.merge_masks( + src_mask, src_key_padding_mask, src + ) + return torch._transformer_encoder_layer_fwd( + src, + self.self_attn.embed_dim, + self.self_attn.num_heads, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.activation_relu_or_gelu == 2, + self.norm_first, + self.norm1.eps, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + merged_mask, + mask_type, + ) + + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = src + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal + ) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1( + x + + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal) + ) + x = self.norm2(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + is_causal=is_causal, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + + This TransformerDecoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first: if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectively. Otherwise it's done after. + Default: ``False`` (after). + bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive + bias. Default: ``True``. + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + + Alternatively, when ``batch_first`` is ``True``: + >>> decoder_layer = nn.TransformerDecoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 20, 512) + >>> out = decoder_layer(tgt, memory) + """ + + __constants__ = ["norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str | Callable[[Tensor], Tensor] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + # pyrefly: ignore [bad-argument-type] + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = F.relu + super().__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. + Default: ``False``. + Warning: + ``tgt_is_causal`` provides a hint that ``tgt_mask`` is + the causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + memory_is_causal: If specified, applies a causal mask as + ``memory mask``. + Default: ``False``. + Warning: + ``memory_is_causal`` provides a hint that + ``memory_mask`` is the causal mask. Providing incorrect + hints can result in incorrect execution, including + forward and backward compatibility. + + Shape: + see the docs in :class:`~torch.nn.Transformer`. + """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal + ) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) + x = x + self._ff_block(self.norm3(x)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask, memory_is_causal + ) + ) + x = self.norm3(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, + is_causal: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout2(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +def _get_clones(module, N): + # FIXME: copy.deepcopy() is not defined on nn.module + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}") + + +def _detect_is_causal_mask( + mask: Tensor | None, + is_causal: bool | None = None, + size: int | None = None, +) -> bool: + """Return whether the given attention mask is causal. + + Warning: + If ``is_causal`` is not ``None``, its value will be returned as is. If a + user supplies an incorrect ``is_causal`` hint, + + ``is_causal=False`` when the mask is in fact a causal attention.mask + may lead to reduced performance relative to what would be achievable + with ``is_causal=True``; + ``is_causal=True`` when the mask is in fact not a causal attention.mask + may lead to incorrect and unpredictable execution - in some scenarios, + a causal mask may be applied based on the hint, in other execution + scenarios the specified mask may be used. The choice may not appear + to be deterministic, in that a number of factors like alignment, + hardware SKU, etc influence the decision whether to use a mask or + rely on the hint. + ``size`` if not None, check whether the mask is a causal mask of the provided size + Otherwise, checks for any causal mask. + """ + # Prevent type refinement + make_causal = is_causal is True + + if is_causal is None and mask is not None: + sz = size if size is not None else mask.size(-2) + causal_comparison = _generate_square_subsequent_mask( + sz, device=mask.device, dtype=mask.dtype + ) + + # Do not use `torch.equal` so we handle batched masks by + # broadcasting the comparison. + if mask.size() == causal_comparison.size(): + make_causal = bool((mask == causal_comparison).all()) + else: + make_causal = False + + return make_causal diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/upsampling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/upsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..29e58bc6a9f3779924584e2934874a1333b3e501 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/upsampling.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs + +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _ratio_2_t, _ratio_any_t, _size_2_t, _size_any_t + +from .module import Module + + +__all__ = ["Upsample", "UpsamplingNearest2d", "UpsamplingBilinear2d"] + + +class Upsample(Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. + + The algorithms available for upsampling are nearest neighbor and linear, + bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, + respectively. + + One can either give a :attr:`scale_factor` or the target output :attr:`size` to + calculate the output size. (You cannot give both, as it is ambiguous) + + Args: + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): + output spatial sizes + scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): + multiplier for spatial size. Has to match input size if it is a tuple. + mode (str, optional): the upsampling algorithm: one of ``'nearest'``, + ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. + Default: ``'nearest'`` + align_corners (bool, optional): if ``True``, the corner pixels of the input + and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is + ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. + + Shape: + - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})` + or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally + align the output and input pixels, and thus the output values can depend + on the input size. This was the default behavior for these modes up to + version 0.3.1. Since then, the default behavior is + ``align_corners = False``. See below for concrete examples on how this + affects the outputs. + + .. note:: + If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`. + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='nearest') + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> m(input) + tensor([[[[1.0000, 1.2500, 1.7500, 2.0000], + [1.5000, 1.7500, 2.2500, 2.5000], + [2.5000, 2.7500, 3.2500, 3.5000], + [3.0000, 3.2500, 3.7500, 4.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + + >>> # Try scaling the same data in a larger tensor + >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3) + >>> input_3x3[:, :, :2, :2].copy_(input) + tensor([[[[1., 2.], + [3., 4.]]]]) + >>> input_3x3 + tensor([[[[1., 2., 0.], + [3., 4., 0.], + [0., 0., 0.]]]]) + + >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session") + >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False + >>> # Notice that values in top left corner are the same with the small input (except at boundary) + >>> m(input_3x3) + tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000], + [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000], + [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000], + [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000], + [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + + >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + >>> # Notice that values in top left corner are now changed + >>> m(input_3x3) + tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000], + [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000], + [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000], + [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000], + [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) + """ + + __constants__ = [ + "size", + "scale_factor", + "mode", + "align_corners", + "name", + "recompute_scale_factor", + ] + name: str + size: _size_any_t | None + scale_factor: _ratio_any_t | None + mode: str + align_corners: bool | None + recompute_scale_factor: bool | None + + def __init__( + self, + size: _size_any_t | None = None, + scale_factor: _ratio_any_t | None = None, + mode: str = "nearest", + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + ) -> None: + super().__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) + + def __setstate__(self, state): + if "recompute_scale_factor" not in state: + state["recompute_scale_factor"] = True + + super().__setstate__(state) + + def extra_repr(self) -> str: + """ + Return the extra representation of the module. + """ + if self.scale_factor is not None: + info = "scale_factor=" + repr(self.scale_factor) + else: + info = "size=" + repr(self.size) + info += ", mode=" + repr(self.mode) + return info + + +class UpsamplingNearest2d(Upsample): + r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> m = nn.UpsamplingNearest2d(scale_factor=2) + >>> m(input) + tensor([[[[1., 1., 2., 2.], + [1., 1., 2., 2.], + [3., 3., 4., 4.], + [3., 3., 4., 4.]]]]) + """ + + def __init__( + self, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, + ) -> None: + super().__init__(size, scale_factor, mode="nearest") + + +class UpsamplingBilinear2d(Upsample): + r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels. + + To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor` + as it's constructor argument. + + When :attr:`size` is given, it is the output size of the image `(h, w)`. + + Args: + size (int or Tuple[int, int], optional): output spatial sizes + scale_factor (float or Tuple[float, float], optional): multiplier for + spatial size. + + .. warning:: + This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is + equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + + .. math:: + H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor + + .. math:: + W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor + + Examples:: + + >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2) + >>> input + tensor([[[[1., 2.], + [3., 4.]]]]) + + >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?") + >>> m = nn.UpsamplingBilinear2d(scale_factor=2) + >>> m(input) + tensor([[[[1.0000, 1.3333, 1.6667, 2.0000], + [1.6667, 2.0000, 2.3333, 2.6667], + [2.3333, 2.6667, 3.0000, 3.3333], + [3.0000, 3.3333, 3.6667, 4.0000]]]]) + """ + + def __init__( + self, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, + ) -> None: + super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffadefe152d527090aef870f87a7a7565eac25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/modules/utils.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs +import collections +from itertools import repeat +from typing import Any + + +__all__ = ["consume_prefix_in_state_dict_if_present"] + + +def _ntuple(n, name="parse"): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") +_quadruple = _ntuple(4, "_quadruple") + + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]: + import torch + + if isinstance(out_size, (int, torch.SymInt)): + # pyrefly: ignore [bad-return] + return out_size + if len(defaults) <= len(out_size): + raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") + return [ + v if v is not None else d + for v, d in zip(out_size, defaults[-len(out_size) :], strict=False) + ] + + +def consume_prefix_in_state_dict_if_present( + state_dict: dict[str, Any], + prefix: str, +) -> None: + r"""Strip the prefix in state_dict in place, if any. + + .. note:: + Given a `state_dict` from a DP/DDP model, a local model can load it by applying + `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling + :meth:`torch.nn.Module.load_state_dict`. + + Args: + state_dict (OrderedDict): a state-dict to be loaded to the model. + prefix (str): prefix. + """ + keys = list(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + # also strip the prefix in metadata if any. + if hasattr(state_dict, "_metadata"): + keys = list(state_dict._metadata.keys()) + for key in keys: + # for the metadata dict, the key can be: + # '': for the DDP module, which we want to remove. + # 'module': for the actual model. + # 'module.xx.xx': for the rest. + if len(key) == 0: + continue + # handling both, 'module' case and 'module.' cases + if key == prefix.replace(".", "") or key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict._metadata[newkey] = state_dict._metadata.pop(key) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8648d10aadc8dec59ea7ebc54aa77cd60ee4f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__init__.py @@ -0,0 +1,27 @@ +from typing_extensions import deprecated + +from torch.nn.parallel.data_parallel import data_parallel, DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter + + +__all__ = [ + "replicate", + "scatter", + "parallel_apply", + "gather", + "data_parallel", + "DataParallel", + "DistributedDataParallel", +] + + +@deprecated( + "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, " + "please use `torch.nn.parallel.DistributedDataParallel` instead.", + category=FutureWarning, +) +class DistributedDataParallelCPU(DistributedDataParallel): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d03115af949a0f50aaadcccd580dd26a391e0d34 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c119ef8d416a337bad2114c27be6b39da20c08c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0adcf94590dc6361c15d5c1e7d50aab7a9d21449 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e5f39353b6a4e8af1b38673a90d267c2254080a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed558965bd76998ee24ebbc4bad1118846e3a37a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a742e59bc924ad9f4e1a941888ec8cb954ae877 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c59e2b484326058231bdef4915952f984354735 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..70a2eace9eff15b06df7958588afd8e1580bb8a7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/_functions.py @@ -0,0 +1,131 @@ +import warnings +from itertools import chain + +import torch +from torch._utils import _get_device_index +from torch.autograd import Function +from torch.nn.parallel import comm + + +class Broadcast(Function): + @staticmethod + def forward(ctx, target_gpus, *inputs): + assert all(i.device.type != "cpu" for i in inputs), ( + "Broadcast function not implemented for CPU tensors" + ) + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.target_gpus = target_gpus + if len(inputs) == 0: + return () + ctx.num_inputs = len(inputs) + ctx.input_device = inputs[0].get_device() + outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) + non_differentiables = [] + for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]): + if not input_requires_grad: + non_differentiables.extend(output[idx] for output in outputs) + ctx.mark_non_differentiable(*non_differentiables) + return tuple(chain.from_iterable(outputs)) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None,) + ReduceAddCoalesced.apply( + ctx.input_device, ctx.num_inputs, *grad_outputs + ) + + +class ReduceAddCoalesced(Function): + @staticmethod + def forward(ctx, destination, num_inputs, *grads): + ctx.target_gpus = [ + grads[i].get_device() for i in range(0, len(grads), num_inputs) + ] + + grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)] + return comm.reduce_add_coalesced(grads_, destination) + + @staticmethod + def backward(ctx, *grad_outputs): + return ( + None, + None, + ) + Broadcast.apply(ctx.target_gpus, *grad_outputs) + + +class Gather(Function): + @staticmethod + def forward(ctx, target_device, dim, *inputs): + assert all(i.device.type != "cpu" for i in inputs), ( + "Gather function not implemented for CPU tensors" + ) + if target_device == "cpu": + ctx.target_device = "cpu" + else: + target_device = _get_device_index(target_device, True) + ctx.target_device = target_device + ctx.dim = dim + ctx.input_gpus = tuple(i.get_device() for i in inputs) + if all(t.dim() == 0 for t in inputs) and dim == 0: + inputs = tuple(t.view(1) for t in inputs) + warnings.warn( + "Was asked to gather along dimension 0, but all " + "input tensors were scalars; will instead unsqueeze " + "and return a vector.", + stacklevel=2, + ) + ctx.unsqueezed_scalar = True + else: + ctx.unsqueezed_scalar = False + ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs) + return comm.gather(inputs, ctx.dim, ctx.target_device) + + @staticmethod + def backward(ctx, grad_output): + scattered_grads = Scatter.apply( + ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output + ) + if ctx.unsqueezed_scalar: + scattered_grads = tuple(g[0] for g in scattered_grads) + return (None, None) + scattered_grads + + +class Scatter(Function): + @staticmethod + def forward(ctx, target_gpus, chunk_sizes, dim, input): + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.dim = dim + ctx.input_device = input.get_device() if input.device.type != "cpu" else -1 + streams = None + if torch.accelerator.is_available() and ctx.input_device == -1: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(torch.device(device)) for device in target_gpus] + outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) + # Synchronize with the copy stream + if streams is not None: + for i, output in enumerate(outputs): + with torch.accelerator.device_index(target_gpus[i]): + main_stream = torch.accelerator.current_stream() + main_stream.wait_stream(streams[i]) + output.record_stream(main_stream) + return outputs + + @staticmethod + def backward(ctx, *grad_output): + return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output) + + +# background streams used for copying +_streams: list[torch.Stream | None] | None = None + + +def _get_stream(device: torch.device): + """Get a background stream for copying between CPU and target device.""" + global _streams + if device.type == "cpu" or not torch.accelerator.is_available(): + return None + assert torch.accelerator.current_accelerator().type == device.type + if _streams is None: + _streams = [None] * torch.accelerator.device_count() + if _streams[device.index] is None: + _streams[device.index] = torch.Stream(device.index) + return _streams[device.index] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/comm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..255c0c4b332712a714610801f11c8e2b33df3671 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/comm.py @@ -0,0 +1,261 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +from torch._utils import ( + _flatten_dense_tensors, + _get_device_index, + _handle_complex, + _reorder_tensors_as, + _take_tensors, + _unflatten_dense_tensors, +) +from torch.cuda import nccl + + +def broadcast(tensor, devices=None, *, out=None): + r"""Broadcasts a tensor to specified GPU devices. + + Args: + tensor (Tensor): tensor to broadcast. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to broadcast. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. + + Returns: + - If :attr:`devices` is specified, + a tuple containing copies of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a copy of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if not ((devices is None) ^ (out is None)): + raise RuntimeError( + f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}" + ) + if devices is not None: + devices = [_get_device_index(d) for d in devices] + return torch._C._broadcast(tensor, devices) + else: + # pyrefly: ignore [bad-argument-type] + return torch._C._broadcast_out(tensor, out) + + +def broadcast_coalesced(tensors, devices, buffer_size=10485760): + """Broadcast a sequence of tensors to the specified GPUs. + + Small tensors are first coalesced into a buffer to reduce the number of synchronizations. + + Args: + tensors (sequence): tensors to broadcast. Must be on the same device, + either CPU or GPU. + devices (Iterable[torch.device, str or int]): an iterable of GPU + devices, among which to broadcast. + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`. + """ + devices = [_get_device_index(d) for d in devices] + tensors = [_handle_complex(t) for t in tensors] + return torch._C._broadcast_coalesced(tensors, devices, buffer_size) + + +def reduce_add(inputs, destination=None): + """Sum tensors from multiple GPUs. + + All inputs should have matching shapes, dtype, and layout. The output tensor + will be of the same shape, dtype, and layout. + + Args: + inputs (Iterable[Tensor]): an iterable of tensors to add. + destination (int, optional): a device on which the output will be + placed (default: current device). + + Returns: + A tensor containing an elementwise sum of all inputs, placed on the + :attr:`destination` device. + """ + destination = _get_device_index(destination, optional=True) + input_size = inputs[0].size() + root_index = None # index of input tensor that already is on the correct device + for i, inp in enumerate(inputs): + assert inp.device.type != "cpu", "reduce_add expects all inputs to be on GPUs" + if inp.get_device() == destination: + root_index = i + if inp.size() != input_size: + got = "x".join(str(x) for x in inp.size()) + expected = "x".join(str(x) for x in input_size) + raise ValueError( + f"input {i} has invalid size: got {got}, but expected {expected}" + ) + if root_index is None: + raise RuntimeError( + "reduce_add expects destination to be on the same GPU with one of the tensors" + ) + + if len(inputs) == 1: + return inputs[0] + + if nccl.is_available(inputs): + result = torch.empty_like(inputs[root_index]) + nccl.reduce(inputs, output=result, root=root_index) + else: + destination_device = torch.device(inputs[root_index].device.type, destination) + nonroot = [t for i, t in enumerate(inputs) if i != root_index] + # make a new tensor w/o clone + result = inputs[root_index] + nonroot[0].to( + device=destination_device, non_blocking=True + ) + for other in nonroot[1:]: + result.add_(other.to(device=destination_device, non_blocking=True)) + return result + + +def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): + """Sum tensors from multiple GPUs. + + Small tensors are first coalesced into a buffer to reduce the number + of synchronizations. + + Args: + inputs (Iterable[Iterable[Tensor]]): iterable of iterables that + contain tensors from a single device. + destination (int, optional): a device on which the output will be + placed (default: current device). + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple of tensors containing an elementwise sum of each group of + inputs, placed on the ``destination`` device. + """ + # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just + # return `inputs`. + dense_tensors: list[list] = [[] for _ in inputs] # shape (num_gpus, num_tensors) + output = [] + ref_order = [] + # process sparse ones first since they may have different sizes on different gpus + for tensor_at_gpus in zip(*inputs, strict=True): + if all(t.is_sparse for t in tensor_at_gpus): + result = reduce_add(tensor_at_gpus, destination) # this will be sparse too + output.append(result) + ref_order.append(tensor_at_gpus[0]) + else: + for coll, t in zip(dense_tensors, tensor_at_gpus, strict=True): + coll.append(t.to_dense() if t.is_sparse else t) + ref_order.append(dense_tensors[0][-1]) + itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors] + # now the dense ones, which have consistent sizes + for chunks in zip(*itrs, strict=True): + flat_tensors = [ + _flatten_dense_tensors(chunk) for chunk in chunks + ] # (num_gpus,) + flat_result = reduce_add(flat_tensors, destination) + for t in _unflatten_dense_tensors(flat_result, chunks[0]): + # The unflattened tensors do not share storage, and we don't expose + # base flat tensor anyways, so give them different version counters. + # See NOTE [ Version Counter in comm.*_coalesced ] + output.append(t.data) + return tuple(_reorder_tensors_as(output, ref_order)) + + +def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None): + """Scatters tensor across multiple GPUs. + + Args: + tensor (Tensor): tensor to scatter. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to scatter. + chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on + each device. It should match :attr:`devices` in length and sums to + ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided + into equal chunks. + dim (int, optional): A dimension along which to chunk :attr:`tensor`. + Default: ``0``. + streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among + which to execute the scatter. If not specified, the default stream will + be utilized. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. Sizes of these tensors must match that of + :attr:`tensor`, except for :attr:`dim`, where the total size must + sum to ``tensor.size(dim)``. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. When + :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and + will be inferred from sizes of :attr:`out`. + + Returns: + - If :attr:`devices` is specified, + a tuple containing chunks of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a chunk of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if out is None: + # pyrefly: ignore [not-iterable] + devices = [_get_device_index(d) for d in devices] + return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) + else: + if devices is not None: + raise RuntimeError( + f"'devices' must not be specified when 'out' is specified, but got devices={devices}" + ) + if chunk_sizes is not None: + raise RuntimeError( + f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}" + ) + return tuple(torch._C._scatter_out(tensor, out, dim, streams)) + + +def gather(tensors, dim=0, destination=None, *, out=None): + r"""Gathers tensors from multiple GPU devices. + + Args: + tensors (Iterable[Tensor]): an iterable of tensors to gather. + Tensor sizes in all dimensions other than :attr:`dim` have to match. + dim (int, optional): a dimension along which the tensors will be + concatenated. Default: ``0``. + destination (torch.device, str, or int, optional): the output device. + Can be CPU or CUDA. Default: the current CUDA device. + out (Tensor, optional, keyword-only): the tensor to store gather result. + Its sizes must match those of :attr:`tensors`, except for :attr:`dim`, + where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``. + Can be on CPU or CUDA. + + .. note:: + :attr:`destination` must not be specified when :attr:`out` is specified. + + Returns: + - If :attr:`destination` is specified, + a tensor located on :attr:`destination` device, that is a result of + concatenating :attr:`tensors` along :attr:`dim`. + - If :attr:`out` is specified, + the :attr:`out` tensor, now containing results of concatenating + :attr:`tensors` along :attr:`dim`. + """ + tensors = [_handle_complex(t) for t in tensors] + if out is None: + if destination == -1: + warnings.warn( + "Using -1 to represent CPU tensor is deprecated. Please use a " + 'device object or string instead, e.g., "cpu".', + FutureWarning, + stacklevel=2, + ) + destination = _get_device_index(destination, allow_cpu=True, optional=True) + return torch._C._gather(tensors, dim, destination) + else: + if destination is not None: + raise RuntimeError( + f"'destination' must not be specified when 'out' is specified, but got destination={destination}" + ) + return torch._C._gather_out(tensors, out, dim) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2319439f092bed9a4277838dcb3b794de64b97 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py @@ -0,0 +1,289 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections.abc import Sequence +from itertools import chain +from typing import Any, Generic, TypeVar + +import torch +from torch._utils import ( + _get_all_device_indices, + _get_available_device_type, + _get_device_index, + _get_devices_properties, +) +from torch.nn.modules import Module +from torch.nn.parallel.parallel_apply import parallel_apply +from torch.nn.parallel.replicate import replicate +from torch.nn.parallel.scatter_gather import gather, scatter_kwargs + + +__all__ = ["DataParallel", "data_parallel"] + + +def _check_balance(device_ids: Sequence[int | torch.device]) -> None: + imbalance_warn = """ + There is an imbalance between your GPUs. You may want to exclude GPU {} which + has less than 75% of the memory or cores of GPU {}. You can do so by setting + the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES + environment variable.""" + device_ids = [_get_device_index(x, True) for x in device_ids] + dev_props = _get_devices_properties(device_ids) + + def warn_imbalance(get_prop) -> bool: + values = [get_prop(props) for props in dev_props] + min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) + max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) + if min_val / max_val < 0.75: + warnings.warn( + imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]), + stacklevel=2, + ) + return True + return False + + if warn_imbalance(lambda props: props.total_memory): + return + if warn_imbalance(lambda props: props.multi_processor_count): + return + + +T = TypeVar("T", bound=Module) + + +class DataParallel(Module, Generic[T]): + r"""Implements data parallelism at the module level. + + This container parallelizes the application of the given :attr:`module` by + splitting the input across the specified devices by chunking in the batch + dimension (other objects will be copied once per device). In the forward + pass, the module is replicated on each device, and each replica handles a + portion of the input. During the backwards pass, gradients from each replica + are summed into the original module. + + The batch size should be larger than the number of GPUs used. + + .. warning:: + It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`, + instead of this class, to do multi-GPU training, even if there is only a single + node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`. + + Arbitrary positional and keyword inputs are allowed to be passed into + DataParallel but some types are specially handled. tensors will be + **scattered** on dim specified (default 0). tuple, list and dict types will + be shallow copied. The other types will be shared among different threads + and can be corrupted if written to in the model's forward pass. + + The parallelized :attr:`module` must have its parameters and buffers on + ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel` + module. + + .. warning:: + In each forward, :attr:`module` is **replicated** on each device, so any + updates to the running module in ``forward`` will be lost. For example, + if :attr:`module` has a counter attribute that is incremented in each + ``forward``, it will always stay at the initial value because the update + is done on the replicas which are destroyed after ``forward``. However, + :class:`~torch.nn.DataParallel` guarantees that the replica on + ``device[0]`` will have its parameters and buffers sharing storage with + the base parallelized :attr:`module`. So **in-place** updates to the + parameters or buffers on ``device[0]`` will be recorded. E.g., + :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm` + rely on this behavior to update the buffers. + + .. warning:: + Forward and backward hooks defined on :attr:`module` and its submodules + will be invoked ``len(device_ids)`` times, each with inputs located on + a particular device. Particularly, the hooks are only guaranteed to be + executed in correct order with respect to operations on corresponding + devices. For example, it is not guaranteed that hooks set via + :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before + `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but + that each such hook be executed before the corresponding + :meth:`~torch.nn.Module.forward` call of that device. + + .. warning:: + When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in + :func:`forward`, this wrapper will return a vector of length equal to + number of devices used in data parallelism, containing the result from + each device. + + .. note:: + There is a subtlety in using the + ``pack sequence -> recurrent network -> unpack sequence`` pattern in a + :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. + See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for + details. + + + Args: + module (Module): module to be parallelized + device_ids (list of int or torch.device): CUDA devices (default: all devices) + output_device (int or torch.device): device location of output (default: device_ids[0]) + + Attributes: + module (Module): the module to be parallelized + + Example:: + + >>> # xdoctest: +SKIP + >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) + >>> output = net(input_var) # input_var can be on any device, including CPU + """ + + # TODO: update notes/cuda.rst when this class handles 8+ GPUs well + + def __init__( + self, + module: T, + device_ids: Sequence[int | torch.device] | None = None, + output_device: int | torch.device | None = None, + dim: int = 0, + ) -> None: + super().__init__() + torch._C._log_api_usage_once("torch.nn.parallel.DataParallel") + device_type = _get_available_device_type() + if device_type is None or device_type == "mps": + self.module = module + self.device_ids = [] + return + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + self.dim = dim + self.module = module + self.device_ids = [_get_device_index(x, True) for x in device_ids] + self.output_device = _get_device_index(output_device, True) + # pyrefly: ignore [read-only] + self.src_device_obj = torch.device(device_type, self.device_ids[0]) + + if device_type == "cuda": + _check_balance(self.device_ids) + + if len(self.device_ids) == 1: + self.module.to(self.src_device_obj) + + def forward(self, *inputs: Any, **kwargs: Any) -> Any: + with torch.autograd.profiler.record_function("DataParallel.forward"): + if not self.device_ids: + return self.module(*inputs, **kwargs) + + # pyrefly: ignore [bad-argument-type] + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {self.src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) + + inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids) + # for forward function without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + if len(self.device_ids) == 1: + return self.module(*inputs[0], **module_kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[: len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, module_kwargs) + return self.gather(outputs, self.output_device) + + def replicate(self, module: T, device_ids: Sequence[int | torch.device]) -> list[T]: + return replicate(module, device_ids, not torch.is_grad_enabled()) + + def scatter( + self, + inputs: tuple[Any, ...], + kwargs: dict[str, Any] | None, + device_ids: Sequence[int | torch.device], + ) -> Any: + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def parallel_apply( + self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any + ) -> list[Any]: + return parallel_apply( + replicas, inputs, kwargs, self.device_ids[: len(replicas)] + ) + + def gather(self, outputs: Any, output_device: int | torch.device) -> Any: + return gather(outputs, output_device, dim=self.dim) + + +def data_parallel( + module: Module, + inputs: Any, + device_ids: Sequence[int | torch.device] | None = None, + output_device: int | torch.device | None = None, + dim: int = 0, + module_kwargs: Any | None = None, +) -> torch.Tensor: + r"""Evaluate module(input) in parallel across the GPUs given in device_ids. + + This is the functional version of the DataParallel module. + + Args: + module (Module): the module to evaluate in parallel + inputs (Tensor): inputs to the module + device_ids (list of int or torch.device): GPU ids on which to replicate module + output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU. + (default: device_ids[0]) + Returns: + a Tensor containing the result of module(input) located on + output_device + """ + if not isinstance(inputs, tuple): + inputs = (inputs,) if inputs is not None else () + + device_type = _get_available_device_type() + + if device_type is None: + raise RuntimeError("device type could not be determined") + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + device_ids = [_get_device_index(x, True) for x in device_ids] + output_device = _get_device_index(output_device, True) + # pyrefly: ignore [no-matching-overload] + src_device_obj = torch.device(device_type, device_ids[0]) + + # pyrefly: ignore [bad-argument-type] + for t in chain(module.parameters(), module.buffers()): + if t.device != src_device_obj: + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) + + inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + # for module without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + assert module_kwargs is not None + + if len(device_ids) == 1: + return module(*inputs[0], **module_kwargs[0]) + used_device_ids = device_ids[: len(inputs)] + replicas = replicate(module, used_device_ids) + outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) + return gather(outputs, output_device, dim) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..4899d123e80a124f31e45ed832bba195af32c353 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/distributed.py @@ -0,0 +1,2434 @@ +# mypy: allow-untyped-defs +import copy +import functools +import inspect +import itertools +import logging +import os +import sys +import warnings +import weakref +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, fields, is_dataclass +from enum import auto, Enum +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.distributed as dist +from torch._utils import _get_device_index +from torch.autograd import Function, Variable +from torch.distributed.algorithms.join import Join, Joinable, JoinHook +from torch.nn.modules import Module +from torch.nn.parallel.scatter_gather import gather, scatter_kwargs +from torch.utils._pytree import tree_flatten, tree_unflatten + + +RPC_AVAILABLE = False +if dist.is_available(): + from torch.distributed.distributed_c10d import ( + _get_default_group, + _rank_not_in_group, + ReduceOp, + ) + from torch.distributed.utils import ( + _alloc_storage, + _cast_forward_inputs, + _free_storage, + _sync_module_states, + _to_kwargs, + _verify_param_shape_across_processes, + ) +if dist.rpc.is_available(): + RPC_AVAILABLE = True + from torch.distributed.rpc import RRef + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +__all__ = ["DistributedDataParallel"] + +logger = logging.getLogger(__name__) + + +@dataclass +class _MixedPrecision: + """ + This configures DDP-native mixed precision training. + + Attributes: + param_dtype (torch.dtype): This specifies the dtype for model + parameters, inputs (when ``cast_forward_inputs`` is set to + ``True``), and therefore the dtype for computation. + However, outside the forward and backward passes, parameters are in + full precision. Model checkpointing always happens in full + precision. + reduce_dtype (torch.dtype): This specifies the dtype for gradient + reduction, which is permitted to differ from ``param_dtype``. + buffer_dtype (torch.dtype): This specifies the dtype for buffers. + + .. note:: This API is experimental and subject to change. + + .. note:: Only floating point tensors are cast to their specified dtypes. + + .. note:: ``state_dict`` checkpoints parameters and buffers in full + precision. + + .. note:: Each low precision dtype must be specified explicitly. For + example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies + the reduction dtype to be low precision, and DDP will not cast + parameters or buffers. + + .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction + happens in ``param_dtype`` if specified or the original parameter dtype + otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)`` + would result in communication occurring in fp16. + """ + + param_dtype: torch.dtype | None = None + reduce_dtype: torch.dtype | None = None + buffer_dtype: torch.dtype | None = None + # TODO (rohan-varma): keep_low_precision_grads: bool = False + # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm + # in full precision. For DDP, this can be implemented by not performing the + # parameter cast for BN and LN units. + + +def _cast_buffers(mixed_precision_config, root_module): + """Casts buffers to the given ``buffer_dtype``.""" + for buf in root_module.buffers(): + if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored: + continue + + buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype) + + +def _setup_mixed_precision_params(mixed_precision_config, root_module): + """Create and free storage for the mixed precision parameters.""" + for param in root_module.parameters(): + # Do not setup mixed precision for DDP ignored parameters. + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: + continue + + if not hasattr(param, "_mp_param"): + param._mp_param = torch.zeros_like( + param, + device=param.device, + dtype=mixed_precision_config.param_dtype, + requires_grad=param.requires_grad, + ) + _free_storage(param._mp_param) + # _fp_param will point to the full precision param so it can be switched + # back to at the end of forward / backward. + param._fp_param = param.data + + +def _tree_flatten_with_rref(output): + output_is_rref = RPC_AVAILABLE and isinstance(output, RRef) + if output_is_rref: + output_tensor_list, treespec = tree_flatten(output.local_value()) + else: + output_tensor_list, treespec = tree_flatten(output) + # Need to return flattened tensors, spec to re-pack them, as well + # as if the return type was actually an RRef to reconstruct. + return output_tensor_list, treespec, output_is_rref + + +def _tree_unflatten_with_rref(output, treespec, output_is_rref): + output = tree_unflatten(output, treespec) + if output_is_rref: + output = RRef(output) + return output + + +def _find_tensors(obj): + r"""Recursively find all tensors contained in the specified object.""" + if RPC_AVAILABLE and isinstance(obj, RRef): + # If the current node is the owner of the RRef, unwrap it and try to + # find Tensors. + # TODO: Expand to remote RRefs. + if obj.is_owner(): + return _find_tensors(obj.local_value()) + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain.from_iterable(map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain.from_iterable(map(_find_tensors, obj.values())) + if is_dataclass(obj): + return itertools.chain.from_iterable( + map(_find_tensors, (getattr(obj, f.name) for f in fields(obj))) + ) + + return [] + + +def _dump_DDP_relevant_env_vars(): + relevant_env_vars = [ + "RANK", + "LOCAL_RANK", + "WORLD_SIZE", + "MASTER_PORT", + "MASTER_ADDR", + "CUDA_VISIBLE_DEVICES", + "GLOO_SOCKET_IFNAME", + "GLOO_DEVICE_TRANSPORT", + "NCCL_SOCKET_IFNAME", + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_DEBUG", + "NCCL_DEBUG_SUBSYS", + "NCCL_IB_DISABLE", + # More NCCL env vars: + "NCCL_P2P_DISABLE", + "NCCL_P2P_LEVEL", + "NCCL_SHM_DISABLE", + "NCCL_SOCKET_NTHREADS", + "NCCL_NSOCKS_PERTHREAD", + "NCCL_BUFFSIZE", + "NCCL_NTHREADS", + "NCCL_RINGS", + "NCCL_MAX_NCHANNELS", + "NCCL_MIN_NCHANNELS", + "NCCL_CHECKS_DISABLE", + "NCCL_CHECK_POINTERS", + "NCCL_LAUNCH_MODE", + "NCCL_IB_HCA", + "NCCL_IB_TIMEOUT", + "NCCL_IB_RETRY_CNT", + "NCCL_IB_GID_INDEX", + "NCCL_IB_SL", + "NCCL_IB_TC", + "NCCL_IB_AR_THRESHOLD", + "NCCL_IB_CUDA_SUPPORT", + "NCCL_NET_GDR_LEVEL", + "NCCL_NET_GDR_READ", + "NCCL_SINGLE_RING_THRESHOLD", + "NCCL_LL_THRESHOLD", + "NCCL_TREE_THRESHOLD", + "NCCL_ALGO", + "NCCL_PROTO", + "NCCL_IGNORE_CPU_AFFINITY", + "NCCL_DEBUG_FILE", + "NCCL_COLLNET_ENABLE", + "NCCL_TOPO_FILE", + "NCCL_TOPO_DUMP_FILE", + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + ] + formatted_output = "" + for var in relevant_env_vars: + value = os.environ.get(var, "N/A") + formatted_output += f"env:{var}={value}\n" + print(formatted_output) + + +class _BufferCommHookLocation(Enum): + PRE_FORWARD = auto() + POST_FORWARD = auto() + + +@dataclass +class _BufferCommHook: + buffer_comm_hook: Callable + buffer_comm_hook_state: Any + buffer_comm_hook_location: _BufferCommHookLocation + + +# Add a DDPSink to run various functions when backwards starts, such as +# queueing call back of out-most backward/graph task, +# this helps call back is fired after all gradients' calculation +# is completed. +class _DDPSink(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, ddp_weakref, *inputs): + # set_materialize_grads(False) will ensure that None gradients stay as + # None and are not filled with zeros. + ctx.set_materialize_grads(False) + ctx.ddp_weakref = ddp_weakref + ret = inputs + if ddp_weakref()._ddp_sink_clone: + ret = tuple( + inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs + ) + return ret + + @staticmethod + def backward(ctx, *grad_outputs): + # Enqueue delay allreduce for static graph training on the first + # iteration. + ddp_weakref = ctx.ddp_weakref() + reducer = ddp_weakref.reducer + static_graph = ddp_weakref.static_graph + delay_ar_enqueued = ( + static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued + ) + if static_graph and not delay_ar_enqueued: + Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc] + reducer._delay_all_reduce + ) + ddp_weakref._static_graph_delay_allreduce_enqueued = True + + return (None, *grad_outputs) + + +class _DDPJoinHook(JoinHook): + def __init__(self, ddp, divide_by_initial_world_size): + """Set config variables for internal usage.""" + assert isinstance(ddp, DistributedDataParallel), ( + "DDP join hook requires passing in a DistributedDataParallel " + "instance as the state" + ) + assert ddp.logger is not None + ddp.logger._set_uneven_input_join() + self.ddp = ddp + self.ddp._divide_by_initial_world_size = divide_by_initial_world_size + super().__init__() + + def main_hook(self): + """Shadow the DDP collective communication operations in the forward and backward passes.""" + ddp = self.ddp + # Buckets are rebuilt only once during a training period + ddp.reducer._rebuild_buckets() + + # Schedule a broadcast if we are syncing module buffers in the + # forward pass + # TODO: make DDP uneven inputs context manager support buffer + # comm hook (https://github.com/pytorch/pytorch/issues/65436) + ddp._check_and_sync_module_buffers() + + # Check if need to sync in the backward pass + should_sync_backwards = ddp._check_global_requires_backward_grad_sync( + is_joined_rank=True + ) + # Forward parameter sync is disabled in the next iteration if we + # are skipping gradient sync this iteration, so set + # `require_forward_param_sync` accordingly + ddp.require_forward_param_sync = should_sync_backwards + if not should_sync_backwards: + return + + # Schedule one allreduce per gradient bucket to match the backward + # pass allreduce + ddp._match_all_reduce_for_bwd_pass() + + # Check if we need to allreduce locally unused parameters + if ddp.find_unused_parameters: + ddp._match_unused_params_allreduce() + + # Rebuilt parameters are pushed only once during a training period + ddp.reducer._push_all_rebuilt_params() + + def post_hook(self, is_last_joiner: bool): + """Sync the final model to ensure that the model is the same across all processes.""" + self.ddp._sync_final_model(is_last_joiner) + + +class DistributedDataParallel(Module, Joinable): + r"""Implement distributed data parallelism based on ``torch.distributed`` at module level. + + This container provides data parallelism by synchronizing gradients + across each model replica. The devices to synchronize across are + specified by the input ``process_group``, which is the entire world + by default. Note that ``DistributedDataParallel`` does not chunk or + otherwise shard the input across participating GPUs; the user is + responsible for defining how to do so, for example through the use + of a :class:`DistributedSampler`. + + See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`. + The same constraints on input as in :class:`torch.nn.DataParallel` apply. + + Creation of this class requires that ``torch.distributed`` to be already + initialized, by calling :func:`torch.distributed.init_process_group`. + + ``DistributedDataParallel`` is proven to be significantly faster than + :class:`torch.nn.DataParallel` for single-node multi-GPU data + parallel training. + + To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn + up ``N`` processes, ensuring that each process exclusively works on a single + GPU from 0 to N-1. This can be done by either setting + ``CUDA_VISIBLE_DEVICES`` for every process or by calling the following API for GPUs, + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.cuda.set_device(i) + + or calling the unified API for :ref:`accelerator`, + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.accelerator.set_device_index(i) + + where i is from 0 to N-1. In each process, you should refer the following + to construct this module: + + >>> # xdoctest: +SKIP("undefined variables") + >>> if torch.accelerator.is_available(): + >>> device_type = torch.accelerator.current_accelerator().type + >>> vendor_backend = torch.distributed.get_default_backend_for_device(device_type) + >>> + >>> torch.distributed.init_process_group( + >>> backend=vendor_backend, world_size=N, init_method='...' + >>> ) + >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) + + Or you can use the latest API for initialization: + + >>> torch.distributed.init_process_group(device_id=i) + + In order to spawn up multiple processes per node, you can use either + ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. + + .. note:: + Please refer to `PyTorch Distributed Overview `__ + for a brief introduction to all features related to distributed training. + + .. note:: + ``DistributedDataParallel`` can be used in conjunction with + :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce + per-rank optimizer states memory footprint. Please refer to + `ZeroRedundancyOptimizer recipe `__ + for more details. + + .. note:: ``nccl`` backend is currently the fastest and highly recommended + backend when using GPUs. This applies to both single-node and + multi-node distributed training. + + .. note:: This module also supports mixed-precision distributed training. + This means that your model can have different types of parameters such + as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these + mixed types of parameters will just work fine. + + .. note:: If you use ``torch.save`` on one process to checkpoint the module, + and ``torch.load`` on some other processes to recover it, make sure that + ``map_location`` is configured properly for every process. Without + ``map_location``, ``torch.load`` would recover the module to devices + where the module was saved from. + + .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the + gradient will be ``M`` times smaller when compared to the same model + trained on a single node with ``batch=M*N`` if the loss is summed (NOT + averaged as usual) across instances in a batch (because the gradients + between different nodes are averaged). You should take this into + consideration when you want to obtain a mathematically equivalent + training process compared to the local training counterpart. But in most + cases, you can just treat a DistributedDataParallel wrapped model, a + DataParallel wrapped model and an ordinary model on a single GPU as the + same (E.g. using the same learning rate for equivalent batch size). + + .. note:: + Parameters are never broadcast between processes. The module performs + an all-reduce step on gradients and assumes that they will be modified + by the optimizer in all processes in the same way. Buffers + (e.g. BatchNorm stats) are broadcast from the module in process of rank + 0, to all other replicas in the system in every iteration. + + .. note:: + If you are using DistributedDataParallel in conjunction with the + :ref:`distributed-rpc-framework`, you should always use + :meth:`torch.distributed.autograd.backward` to compute gradients and + :class:`torch.distributed.optim.DistributedOptimizer` for optimizing + parameters. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch.distributed.autograd as dist_autograd + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> import torch + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> import torch.distributed.rpc as rpc + >>> from torch.distributed.rpc import RRef + >>> + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) + >>> ddp_model = DDP(my_model) + >>> + >>> # Setup optimizer + >>> optimizer_params = [rref] + >>> for param in ddp_model.parameters(): + >>> optimizer_params.append(RRef(param)) + >>> + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> optimizer_params, + >>> lr=0.05, + >>> ) + >>> + >>> with dist_autograd.context() as context_id: + >>> pred = ddp_model(rref.to_here()) + >>> loss = loss_func(pred, target) + >>> dist_autograd.backward(context_id, [loss]) + >>> dist_optim.step(context_id) + + .. note:: + DistributedDataParallel currently offers limited support for gradient + checkpointing with :meth:`torch.utils.checkpoint`. + If the checkpoint is done with use_reentrant=False (recommended), DDP + will work as expected without any limitations. + If, however, the checkpoint is done with use_reentrant=True (the default), + DDP will work as expected when there are no unused parameters in the model + and each layer is checkpointed at most once (make sure you are not passing + `find_unused_parameters=True` to DDP). We currently do not support the + case where a layer is checkpointed multiple times, or when there unused + parameters in the checkpointed model. + + .. note:: + To let a non-DDP model load a state dict from a DDP model, + :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present` + needs to be applied to strip the prefix "module." in the DDP state dict before loading. + + .. warning:: + Constructor, forward method, and differentiation of the output (or a + function of the output of this module) are distributed synchronization + points. Take that into account in case different processes might be + executing different code. + + .. warning:: + This module assumes all parameters are registered in the model by the + time it is created. No parameters should be added nor removed later. + Same applies to buffers. + + .. warning:: + This module assumes all parameters are registered in the model of each + distributed processes are in the same order. The module itself will + conduct gradient ``allreduce`` following the reverse order of the + registered parameters of the model. In other words, it is users' + responsibility to ensure that each distributed process has the exact + same model and thus the exact same parameter registration order. + + .. warning:: + This module allows parameters with non-rowmajor-contiguous strides. + For example, your model may contain some parameters whose + :class:`torch.memory_format` is ``torch.contiguous_format`` + and others whose format is ``torch.channels_last``. However, + corresponding parameters in different processes must have the + same strides. + + .. warning:: + This module doesn't work with :func:`torch.autograd.grad` (i.e. it will + only work if gradients are to be accumulated in ``.grad`` attributes of + parameters). + + .. warning:: + If you plan on using this module with a ``nccl`` backend or a ``gloo`` + backend (that uses Infiniband), together with a DataLoader that uses + multiple workers, please change the multiprocessing start method to + ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately + Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will + likely experience deadlocks if you don't change this setting. + + .. warning:: + You should never try to change your model's parameters after wrapping + up your model with ``DistributedDataParallel``. Because, when + wrapping up your model with ``DistributedDataParallel``, the constructor + of ``DistributedDataParallel`` will register the additional gradient + reduction functions on all the parameters of the model itself at the + time of construction. If you change the model's parameters afterwards, + gradient reduction functions no longer match the correct set of + parameters. + + .. warning:: + Using ``DistributedDataParallel`` in conjunction with the + :ref:`distributed-rpc-framework` is experimental and subject to change. + + Args: + module (Module): module to be parallelized + device_ids (list of int or torch.device): CUDA devices. + 1) For single-device modules, ``device_ids`` can + contain exactly one device id, which represents the only + CUDA device where the input module corresponding to this process resides. + Alternatively, ``device_ids`` can also be ``None``. + 2) For multi-device modules and CPU modules, + ``device_ids`` must be ``None``. + + When ``device_ids`` is ``None`` for both cases, + both the input data for the forward pass and the actual module + must be placed on the correct device. + (default: ``None``) + output_device (int or torch.device): Device location of output for + single-device CUDA modules. For multi-device modules and + CPU modules, it must be ``None``, and the module itself + dictates the output location. (default: ``device_ids[0]`` + for single-device modules) + broadcast_buffers (bool): Flag that enables syncing (broadcasting) + buffers of the module at beginning of the ``forward`` + function. (default: ``True``) + init_sync (bool): Whether to sync during initialization to verify param + shapes and broadcast parameters and buffers. + WARNING: if this is set to False the user is required + to ensure themselves that the weights are the same on + all ranks. + (default: ``True``) + process_group: The process group to be used for distributed data + all-reduction. If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. (default: ``None``) + bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into + multiple buckets so that gradient reduction of each + bucket can potentially overlap with backward computation. + :attr:`bucket_cap_mb` controls the bucket size in + MebiBytes (MiB). If ``None``, a default size of 25 MiB + will be used. (default: ``None``) + find_unused_parameters (bool): Traverse the autograd graph from all + tensors contained in the return value of the + wrapped module's ``forward`` function. Parameters + that don't receive gradients as part of this + graph are preemptively marked as being ready to + be reduced. In addition, parameters that may have + been used in the wrapped module's ``forward`` + function but were not part of loss computation and + thus would also not receive gradients are + preemptively marked as ready to be reduced. + (default: ``False``) + check_reduction: This argument is deprecated. + gradient_as_bucket_view (bool): When set to ``True``, gradients will be views + pointing to different offsets of ``allreduce`` communication + buckets. This can reduce peak memory usage, where the + saved memory size will be equal to the total gradients + size. Moreover, it avoids the overhead of copying between + gradients and ``allreduce`` communication buckets. When + gradients are views, ``detach_()`` cannot be called on the + gradients. If hitting such errors, please fix it by + referring to the :meth:`~torch.optim.Optimizer.zero_grad` + function in ``torch/optim/optimizer.py`` as a solution. + Note that gradients will be views after first iteration, so + the peak memory saving should be checked after first iteration. + static_graph (bool): When set to ``True``, DDP knows the trained graph is + static. Static graph means 1) The set of used and unused + parameters will not change during the whole training loop; in + this case, it does not matter whether users set + ``find_unused_parameters = True`` or not. 2) How the graph is trained + will not change during the whole training loop (meaning there is + no control flow depending on iterations). + When static_graph is set to be ``True``, DDP will support cases that + can not be supported in the past: + 1) Reentrant backwards. + 2) Activation checkpointing multiple times. + 3) Activation checkpointing when model has unused parameters. + 4) There are model parameters that are outside of forward function. + 5) Potentially improve performance when there are unused parameters, + as DDP will not search graph in each iteration to detect unused + parameters when static_graph is set to be ``True``. + To check whether you can set static_graph to be ``True``, one way is to + check ddp logging data at the end of your previous model training, + if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you + can set ``static_graph = True`` as well. + + Example:: + >>> # xdoctest: +SKIP("undefined variables") + >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) + >>> # Training loop + >>> ... + >>> ddp_logging_data = model_DDP._get_ddp_logging_data() + >>> static_graph = ddp_logging_data.get("can_set_static_graph") + delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list + of named parameters whose all reduce will be delayed when the gradient of + the parameter specified in ``param_to_hook_all_reduce`` is ready. Other + arguments of DDP do not apply to named params specified in this argument + as these named params will be ignored by DDP reducer. + param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce + of parameters specified in ``delay_all_reduce_named_params``. + skip_all_reduce_unused_params: When set to True, DDP will skip reducing unused parameters. + This requires that unused parameters remain the same across all ranks throughout + the entire training process. If this condition is not met, it may cause + desynchronization and result in training hang. + + + Attributes: + module (Module): the module to be parallelized. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') + >>> net = torch.nn.parallel.DistributedDataParallel(model) + """ + + # used to track whether the given thread is inside ddp forward for torchdynamo purposes + _active_ddp_module: Optional["DistributedDataParallel"] = None + + def __init__( + self, + module, + device_ids=None, + output_device=None, + dim=0, + broadcast_buffers=True, + init_sync=True, + process_group=None, + bucket_cap_mb=None, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False, + static_graph=False, + delay_all_reduce_named_params=None, + param_to_hook_all_reduce=None, + mixed_precision: _MixedPrecision | None = None, + device_mesh=None, + skip_all_reduce_unused_params=False, + ): + super().__init__() + Joinable.__init__(self) + self._use_python_reducer = ( + torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer" + ) + self.logger: dist.Logger | None = None + if bool(delay_all_reduce_named_params is not None) != bool( + param_to_hook_all_reduce is not None + ): + self._log_and_throw( + ValueError, + "delay_all_reduce_named_params and param_to_hook_all_reduce " + "need to be set at the same time.", + ) + + if process_group and device_mesh is not None: + raise RuntimeError( + "Cannot specify both process_group and device_mesh arguments." + ) + elif process_group is None and device_mesh is None: + self.process_group = _get_default_group() + elif device_mesh is None: + # pyrefly: ignore [bad-assignment] + self.process_group = process_group + else: + if device_mesh.ndim != 1: + raise RuntimeError( + f"Only 1D device mesh is supported, but got {device_mesh}." + ) + self.device_mesh = device_mesh + self.process_group = device_mesh.get_group(mesh_dim=0) + + root_mesh = device_mesh._get_root_mesh() + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # This has to be done before check UninitializedParameter. + from torch.distributed.tensor.parallel.ddp import ( + _pre_dp_module_transform, + ) + + _pre_dp_module_transform(module) + + self._delay_all_reduce_params = [] + if hasattr(module, "_ddp_params_and_buffers_to_ignore"): + self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) + else: + self.parameters_to_ignore = set() + if delay_all_reduce_named_params is not None: + for name, param in delay_all_reduce_named_params: + self.parameters_to_ignore.add(name) + self._delay_all_reduce_params.append(param) + + self._module_parameters = [ + p + for n, p in module.named_parameters() + if n not in self.parameters_to_ignore + ] + if not any(p.requires_grad for p in self._module_parameters): + if len(self._delay_all_reduce_params): + logger.info("Delay the AllReduce of all parameters.") + else: + self._log_and_throw( + RuntimeError, + "DistributedDataParallel is not needed when a module " + "doesn't have any parameter that requires a gradient.", + ) + + if device_ids is not None and len(device_ids) > 1: + self._log_and_throw( + ValueError, + "device_ids can only be None or contain a single element.", + ) + + self.is_multi_device_module = ( + len({p.device for p in self._module_parameters}) > 1 + ) + distinct_device_types = { + p.device.type for p in self._module_parameters if p.device is not None + } + if len(distinct_device_types) != 1: + self._log_and_throw( + ValueError, + "DistributedDataParallel's input module must be on " + f"the same type of devices, but input module parameters locate in {distinct_device_types}.", + ) + + self.device_type = next(iter(distinct_device_types)) + + if ( + device_ids is None + or len(device_ids) == 0 # For backward compatibility. + or self.device_type == "cpu" + or self.is_multi_device_module + ): + if device_ids or output_device: + self._log_and_throw( + ValueError, + "DistributedDataParallel device_ids and output_device arguments " + "only work with single-device/multiple-device GPU modules or CPU modules, " + f"but got device_ids {device_ids}, output_device {output_device}, " + f"and module parameters { ({p.device for p in self._module_parameters}) }.", + ) + + self.device_ids = None + self.output_device = None + else: + # pyrefly: ignore [bad-assignment] + self.device_ids = [_get_device_index(x, True) for x in device_ids] + + if output_device is None: + output_device = device_ids[0] + + # pyrefly: ignore [bad-assignment] + self.output_device = _get_device_index(output_device, True) + + self.static_graph = False + self.dim = dim + self.module = module + self.device = next(iter(self._module_parameters)).device + self.broadcast_buffers = broadcast_buffers + self.find_unused_parameters = find_unused_parameters + self.require_backward_grad_sync = True + self.require_forward_param_sync = True + self.gradient_as_bucket_view = gradient_as_bucket_view + self.mixed_precision = mixed_precision + if self.mixed_precision is not None: + logger.warning("Received mixed precision config %s", self.mixed_precision) + + if check_reduction: + # This argument is no longer used since the reducer + # will ensure reduction completes even if some parameters + # do not receive gradients. + warnings.warn( + "The `check_reduction` argument in `DistributedDataParallel` " + "module is deprecated. Please avoid using it.", + FutureWarning, + stacklevel=2, + ) + + # Check that a module does not have Uninitialized parameters + for param in self._module_parameters: + if isinstance(param, torch.nn.parameter.UninitializedParameter): + self._log_and_throw( + RuntimeError, + "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. " + "Run a dummy forward pass to correctly initialize the modules", + ) + # used for intra-node param sync and inter-node sync as well + self.broadcast_bucket_size = 250 * 1024 * 1024 + + # reduction bucket size + if bucket_cap_mb is None: + # default case (bucket cap is 25 MiB) + bucket_cap_mb = 25 + self.bucket_bytes_cap_default = True + else: + self.bucket_bytes_cap_default = False + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + + # Whether to perform input tensor CPU to GPU copies on a side-stream + self.use_side_stream_for_tensor_copies = ( + os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" + ) + + # Initialize gradient buffers and register all reduce hook + self._delay_grad_buffer: torch.Tensor | None = None + self._delay_grad_views: list[torch.Tensor] = [] + self._delay_all_reduce_all_params = False + if len(self._delay_all_reduce_params) != 0: + self._register_delay_all_reduce_hook( + bucket_cap_mb=bucket_cap_mb, + param_to_hook_all_reduce=param_to_hook_all_reduce, + device_ids=device_ids, + ) + if self._delay_all_reduce_all_params: + return + + self.skip_all_reduce_unused_params = skip_all_reduce_unused_params + + # Build parameters for reducer. + parameters, expect_sparse_gradient = self._build_params_for_reducer() + + # All collectives during initialization are gated by this flag. + if init_sync: + # Verify model equivalence. + _verify_param_shape_across_processes(self.process_group, parameters) + # Sync params and buffers. Ensures all DDP models start off at the same value. + _sync_module_states( + module=self.module, + process_group=self.process_group, + broadcast_bucket_size=self.broadcast_bucket_size, + src=0, + params_and_buffers_to_ignore=self.parameters_to_ignore, + broadcast_buffers=self.broadcast_buffers, + ) + + # In debug mode, build a mapping of parameter index -> parameter. + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) + + # Builds reducer. + self._ddp_init_helper( + parameters, + expect_sparse_gradient, + param_to_name_mapping, + static_graph, + ) + self._comm_hooks: list[tuple[Callable, object]] = [] + + if self.mixed_precision is not None: + _setup_mixed_precision_params(self.mixed_precision, self.module) + _cast_buffers(self.mixed_precision, self.module) + # Stream used for async low precision copies. + self._mp_stream = torch.Stream() + self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] + # Add forward pre-hook to root module to kick off copies to lower + # precision. + self.module.register_forward_pre_hook( + self._root_copy_hook, prepend=False, with_kwargs=True + ) + # Add forward pre hook to all submodules to wait for copy events + # before running computation. + for module in self.module.modules(): + module.register_forward_pre_hook( + self._module_wait_for_copy_hook, + prepend=False, + with_kwargs=True, + ) + # Set up callbacks in backward to upcast and use full precision + # params. TODO (rohan-varma): Make this compose with general + # comm hooks and apply_optimizer_in_backward. Importing inline to + # avoid circular import issue. + from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import ( + _AllreduceUpcastHookState, + _reducer_allreduce_and_upcast_hook, + ) + + upcast_hook_state = _AllreduceUpcastHookState( + ddp_weakref=weakref.ref(self), + upcast_stream=torch.Stream(), + ) + self.register_comm_hook( + upcast_hook_state, + _reducer_allreduce_and_upcast_hook, + ) + # Inform reducer of reduced precision param dtype for correctness + # of type checks between gradient and bucket. + self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined] + self.mixed_precision.param_dtype + ) + + self._has_rebuilt_buckets = False + + if static_graph: + self._set_static_graph() + + self._lazy_init_ran = False + + # Register the AccumulateGrad post hooks if optimize_ddp is + # True. The hooks will be deregistered if compiled_autograd is not + # enabled. + self._accum_grad_hooks: list[RemovableHandle] = [] + if self._use_python_reducer: + # pyrefly: ignore [bad-assignment] + torch._inductor.config._fuse_ddp_communication = True + torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb + # Directly adding this to the trace rule will disturb the users + # who are using DDPOptimizer. + torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add( + "torch.nn.parallel.distributed" + ) + torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() + # NOTE: we should init these lazily + self._register_accum_grad_hook() + + # Whether or not DDPSink performs a clone. + self._ddp_sink_clone = True + + def _register_accum_grad_hook(self): + import torch.distributed._functional_collectives as fcol + + def compiled_accum_grad_hook( + param, + *, + param_index: int, + ): + if not self.require_backward_grad_sync: + return + + if param.grad is None: + return + + if self._comm_hooks: + for hook, state in self._comm_hooks: + hook(state, (param.grad, param)) + else: + gradient = param.grad / self.process_group.size() + gradient = fcol.all_reduce(gradient, "sum", self.process_group) + param.grad.copy_(gradient) + + for index, param in enumerate(self._module_parameters): + if not param.requires_grad: + continue + self._accum_grad_hooks.append( + param.register_post_accumulate_grad_hook( + functools.partial( + compiled_accum_grad_hook, + param_index=index, + ) + ) + ) + + def _delayed_all_reduce_hook(self, grad): + world_size = dist.get_world_size(self.process_group) + + self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr] + _ = dist.all_reduce( + self._delay_grad_buffer, group=self.process_group, async_op=True + ) + return grad + + def _register_delay_all_reduce_hook( + self, + bucket_cap_mb, + param_to_hook_all_reduce, + device_ids, + ): + # 1. Create gradient buffer + device = torch.device("cpu") if device_ids is None else device_ids[0] + self._delay_grad_buffer = torch.zeros( + sum(p.numel() for p in self._delay_all_reduce_params), + device=device, + ) + + # 2. Broadcast the parameters + detached_params = [p.detach() for p in self._delay_all_reduce_params] + dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0) + + # 3. Hook all reduce to the specified parameter + param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook) + + # 4. Build tensor views for gradients + offset = 0 + for param in self._delay_all_reduce_params: + grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view( + param.shape + ) + self._delay_grad_views.append(grad_view) + offset = offset + param.numel() + + # 5. Check whether the all reduce of all params requiring grad is delayed. + for module_name, module in self.module.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if param.requires_grad: + full_name = f"{module_name}.{param_name}" + if full_name not in self.parameters_to_ignore: + # There is at least a param whose all reduce will not be delayed. + # In this case, we should not set self._delay_all_reduce_all_params + # to True. + return + self._delay_all_reduce_all_params = True + + def _setup_in_backward_optimizers(self): + # Check if user has used apply_optim_in_backward to overlap optimizer + # step + DDP backward. Current constraints: + # 1. Only allreduce is supported at the moment, no custom communication. + # 2. For DDP-managed parameters that have their optimizer run in + # backward, their gradients are set to ``None``. If your use case + # requires DDP parameters grad not to be set to ``None`` after their + # in-backward optimizer runs, please ping + # https://github.com/pytorch/pytorch/issues/90052. + # NOTE: we use self._module_parameters instead of .parameters() since + # the former excludes ignored (non-DDP managed) parameters. + if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters): + torch._C._log_api_usage_once("ddp.optimizer_in_backward") + # Remove hooks that apply_optim_in_backward had registered because + # DDP customizes how optimizer is overlapped with backward due to + # the allreduce. + param_to_handle_map = ( + dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map + ) + for p in self._module_parameters: + for handle in param_to_handle_map.get(p, []): + handle.remove() + + # Need a weakref to DDP instance to run all_reduce (from reducer) + # and get managed DDP parameters. + ddp_weakref = weakref.ref(self) + # Note: importing in function, otherwise this will cause a circular + # import. + from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _apply_optim_in_backward_hook, + ) + + self.register_comm_hook( + ddp_weakref, + _apply_optim_in_backward_hook( + gradient_is_bucket_view=self.gradient_as_bucket_view + ), + ) + + self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined] + + def _fire_reducer_autograd_hook(self, idx, *unused): + """ + Fire the reducer's autograd hook to allreduce params in a Reducer bucket. + + Note that this is only used during mixed precision training as the + Reducer's hooks installed during construction time would not be called + as we're working in the low precision parameter setting. + """ + self.reducer._autograd_hook(idx) # type: ignore[attr-defined] + + def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None: + """ + For DDP mixed precision, put low precision copies on separate stream and create events to wait for them. + + When training with DDP mixed precision, this root pre-forward hook kicks + off low precision copies on a separate stream and creates respective + events to wait for them. + """ + # Clear out previous iteration submodule to event. This is because we + # may have populated some events for modules that didn't end up being + # used. + self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] + with self._mp_stream: + for submodule in self.module.modules(): + for param in submodule.parameters(recurse=False): + # Do not cast DDP ignored parameters. + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: + continue + _alloc_storage(param._mp_param, param.size()) + # copy() implicitly casts to low precision + with torch.no_grad(): + param._mp_param.copy_(param.data) + # TODO: when zero_grad(set_to_none=False) or in grad + # accumulation case, accumulated grads can be in fp32 + # which can cause errors when running DDP backwards due + # to mismatched incoming and accumulated gradient types. + # So we manually cast the accumulated grad down for now, + # in the future we may shift to FSDP style gradient + # accumulation management where the accumulated gradient + # is saved and .grad field is set to None, bypassing + # this issue. + if param.grad is not None: + param.grad.data = param.grad.to( + self.mixed_precision.param_dtype # type: ignore[union-attr] + ) + param.data = param._mp_param + copy_event = torch.Event() + copy_event.record() + self._submodule_to_event[submodule].append(copy_event) + + def _module_wait_for_copy_hook( + self, + module, + *args: Any, + **kwargs: Any, + ) -> None: + """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished.""" + try: + event = self._submodule_to_event[module].popleft() + except IndexError: + # copy event has already been waited on + return + + event.wait(stream=torch.accelerator.current_stream()) + for p in module.parameters(recurse=False): + # Don't register hooks if param does not require grad + if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored): + continue + # We need to register autograd hook here instead of DDP's ctor + # since we're working with the low precision param. Register them + # via obtaining the gradient accumulator. + tmp = p.expand_as(p) + grad_acc = tmp.grad_fn.next_functions[0][0] + + hook = grad_acc.register_hook( + functools.partial(self._fire_reducer_autograd_hook, p._idx) + ) + p._ddp_mp_hook_state = (grad_acc, hook) + + def _log_and_throw(self, err_type, err_msg): + if self.logger is not None: + self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}") + raise err_type(err_msg) + + def _ddp_init_helper( + self, + parameters, + expect_sparse_gradient, + param_to_name_mapping, + static_graph, + ): + """ + DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm. + + Initialization helper function that does the following: + (1) bucketing the parameters for reductions + (2) resetting the bucketing states + (3) registering the grad hooks + (4) Logging construction-time DDP logging data + (5) passing a handle of DDP to SyncBatchNorm Layer + """ + # Notice, the parameters order is not in the order in which they are used, + # especially in models with control flow. + # + # Alongside parameters are not presented in the real execution order, + # if a certain model happens to also + # 1) have other collectives comm ops in its backward graph. + # 2) have unused parameter in subset ranks of the whole world. + # bucketing could insert ALL-REDUCE comm op too early on the rank with unused parameter, + # matching up with other collectives comm ops on other ranks unexpectedly. + # + # In order to handle this corner case, when the parameters are not in the real execution order, + # we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients + # of the whole graph are computed. + # + # Notice, here we only disable bucketing for the first iteration. + # After the first iteration, it's OK to rebuild buckets, + # because "bucket rebuild" bucketizes parameters based on its real execution order in backward graph. + + # Can remove this branching once #73732 is landed. + if static_graph is True or self.find_unused_parameters is False: + bucket_size_limits = [sys.maxsize] + else: + if self.bucket_bytes_cap_default: + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] + else: + bucket_size_limits = [self.bucket_bytes_cap] + ( + bucket_indices, + per_bucket_size_limits, + ) = dist._compute_bucket_assignment_by_size( + parameters, + bucket_size_limits, + expect_sparse_gradient, + ) + + # Remember index for parameters if we are in mixed precision, as we + # need to pass in index to Reducer's autograd hook via python. + if self.mixed_precision is not None: + for i, p in enumerate(parameters): + p._idx = i + + # Note: reverse list of buckets because we want to approximate the + # order in which their gradients are produced, and assume they + # are used in the forward pass in the order they are defined. + self.reducer = dist.Reducer( + parameters, + list(reversed(bucket_indices)), + list(reversed(per_bucket_size_limits)), + self.process_group, + expect_sparse_gradient, + # The bucket size limit is specified in the constructor. + # Additionally, we allow for a single small bucket for parameters + # that are defined first, such that their gradients don't spill into + # a much larger bucket, adding unnecessary latency after gradient + # computation finishes. Experiments showed 1MB is a reasonable value. + self.bucket_bytes_cap, + self.find_unused_parameters, + self.gradient_as_bucket_view, + param_to_name_mapping, + # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first + # bucket. + ( + dist._DEFAULT_FIRST_BUCKET_BYTES + if self.bucket_bytes_cap_default + else self.bucket_bytes_cap + ), + self.skip_all_reduce_unused_params, + self._use_python_reducer, + ) + + self.logger = dist.Logger(self.reducer) + # Set as a weak reference to avoid reference cycle between + # logger and reducer. + self.reducer.set_logger(self.logger) + + has_sync_bn = False + for submodule in self.module.modules(): + if isinstance(submodule, torch.nn.SyncBatchNorm): + has_sync_bn = True + break + + # Set logging data that can be got during construction time. + self.logger.set_construction_data_and_log( + self.module.__class__.__name__, + [] if self.device_ids is None else self.device_ids, + -1 if self.output_device is None else self.output_device, + self.broadcast_buffers, + has_sync_bn, + static_graph, + ) + + # passing a handle to torch.nn.SyncBatchNorm layer + self._passing_sync_batchnorm_handle(self.module) + + def __getstate__(self): + self._check_default_group() + attrs = copy.copy(self.__dict__) + del attrs["process_group"] + del attrs["reducer"] + del attrs["logger"] + return attrs + + def __setstate__(self, state): + # If serializable, then the process group should be the default one + self.process_group = _get_default_group() + super().__setstate__(state) + self.__dict__.setdefault("require_forward_param_sync", True) + self.__dict__.setdefault("require_backward_grad_sync", True) + parameters, expect_sparse_gradient = self._build_params_for_reducer() + # In debug mode, build a mapping of parameter index -> parameter. + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) + # Builds reducer. + self._ddp_init_helper( + parameters, + expect_sparse_gradient, + param_to_name_mapping, + self.static_graph, + ) + if self.static_graph: + self.reducer._set_static_graph() + assert self.logger is not None + self.logger._set_static_graph() + + def _build_params_for_reducer(self): + # Build tuple of (module, parameter) for all parameters that require grads. + modules_and_parameters = [ + (module, parameter) + for module_name, module in self.module.named_modules() + for parameter in [ + param + # Note that we access module.named_parameters instead of + # parameters(module). parameters(module) is only needed in the + # single-process multi device case, where it accesses replicated + # parameters through _former_parameters. + for param_name, param in module.named_parameters(recurse=False) + if param.requires_grad + and f"{module_name}.{param_name}" not in self.parameters_to_ignore + ] + ] + + # Deduplicate any parameters that might be shared across child modules. + memo = set() + modules_and_parameters = [ + # "p not in memo" is the deduplication check. + # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. + (m, p) + for m, p in modules_and_parameters + if p not in memo and not memo.add(p) # type: ignore[func-returns-value] + ] + + # Build list of parameters. + parameters = [parameter for _, parameter in modules_and_parameters] + + # Checks if a module will produce a sparse gradient. + def produces_sparse_gradient(module): + if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)): + return module.sparse + return False + + # Build list of booleans indicating whether or not to expect sparse + # gradients for the corresponding parameters. + expect_sparse_gradient = [ + produces_sparse_gradient(module) for module, _ in modules_and_parameters + ] + + self._assign_modules_buffers() + + return parameters, expect_sparse_gradient + + def _assign_modules_buffers(self): + """ + Assign self.module.named_buffers to self.modules_buffers. + + Assigns module buffers to self.modules_buffers which are then used to + broadcast across ranks when broadcast_buffers=True. Note that this + must be called every time buffers need to be synced because buffers can + be reassigned by user module, + see https://github.com/pytorch/pytorch/issues/63916. + """ + # Collect buffers for modules, filtering out buffers that should be ignored. + named_module_buffers = [ + (buffer, buffer_name) + for buffer_name, buffer in self.module.named_buffers() + if buffer_name not in self.parameters_to_ignore + ] + self.modules_buffers = [ + buffer for (buffer, buffer_name) in named_module_buffers + ] + # Dict[str, tensor] representing module buffers not ignored by DDP. + self.named_module_buffers = { + buffer_name: buffer for (buffer, buffer_name) in named_module_buffers + } + + def _build_debug_param_to_name_mapping(self, parameters): + param_to_param_index = {parameters[i]: i for i in range(len(parameters))} + param_set = set(parameters) + param_index_to_param_fqn = {} + for module_name, module in self.module.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fqn = f"{module_name}.{param_name}" + # Bypass ignored parameters since those are not reduced by DDP + # to begin with. + if fqn not in self.parameters_to_ignore and param.requires_grad: + if param not in param_set: + self._log_and_throw( + ValueError, + f"Param with name {fqn} found in module parameters, but not DDP parameters." + " This indicates a bug in DDP, please report an issue to PyTorch.", + ) + param_index = param_to_param_index[param] + param_index_to_param_fqn[param_index] = fqn + + # Ensure we covered all parameters + if len(param_set) != len(param_index_to_param_fqn): + self._log_and_throw( + ValueError, + ( + "Expected param to name mapping to cover all parameters, but" + f" got conflicting lengths: {len(param_set)} vs " + f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP" + ", please report an issue to PyTorch." + ), + ) + + return param_index_to_param_fqn + + def _get_parameters(self, m, recurse=True): + """Return a generator of module parameters.""" + + def model_parameters(m): + ps = ( + m._former_parameters.values() + if hasattr(m, "_former_parameters") + else m.parameters(recurse=False) + ) + yield from ps + + for mod in m.modules() if recurse else [m]: + yield from model_parameters(mod) + + def _check_default_group(self): + pickle_not_supported = False + try: + if self.process_group != _get_default_group(): + pickle_not_supported = True + except RuntimeError: + pickle_not_supported = True + + if pickle_not_supported: + self._log_and_throw( + RuntimeError, + "DDP Pickling/Unpickling are only supported " + "when using DDP with the default process " + "group. That is, when you have called " + "init_process_group and have not passed " + "process_group argument to DDP constructor", + ) + + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient synchronizations across DDP processes. + + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) + >>> with ddp.no_sync(): + >>> for input in inputs: + >>> ddp(input).backward() # no synchronization, accumulate grads + >>> ddp(another_input).backward() # synchronize grads + + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + old_require_backward_grad_sync = self.require_backward_grad_sync + self.require_backward_grad_sync = False + try: + yield + finally: + self.require_backward_grad_sync = old_require_backward_grad_sync + + @classmethod + def _get_active_ddp_module(cls): + """`TorchDynamo` requires DDP's status and module for cooperative optimization.""" + return cls._active_ddp_module + + # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in + # for the 'module_to_run' underneath + # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details + @contextmanager + @torch._disable_dynamo(recursive=False) + def _inside_ddp_forward(self): + DistributedDataParallel._active_ddp_module = self + try: + yield + finally: + DistributedDataParallel._active_ddp_module = None + + def _run_ddp_forward(self, *inputs, **kwargs): + if self._use_python_reducer: + return self.module(*inputs, **kwargs) # type: ignore[index] + else: + with self._inside_ddp_forward(): + return self.module(*inputs, **kwargs) # type: ignore[index] + + def _clear_grad_buffer(self): + # Making param.grad points to the grad buffers before backward is based on the + # assumption that the grad accumulation is done in place in autograd engine, + # for some edge cases, if the grad accumulation in autograd engine is not in + # place, then the param.grad and grad buffers are detached. + if self._delay_grad_buffer is not None: + # We batch zero_grad for all params by resetting the whole grad + # buffer when the grad of all params is set to None. + all_param_grad_none = all( + param.grad is None for param in self._delay_all_reduce_params + ) + + for index, param in enumerate(self._delay_all_reduce_params): + if param.grad is None: + param.grad = self._delay_grad_views[index] + if not all_param_grad_none: + param.grad.zero_() + + if all_param_grad_none: + self._delay_grad_buffer.zero_() + + def _lazy_init(self): + # Initialization for DDP that occurs after construction, but lazily + # before the first forward pass. + self._setup_in_backward_optimizers() + self._lazy_init_ran = True + + def _pre_forward(self, *inputs, **kwargs): + if self._use_python_reducer: + return inputs, kwargs + + if not self._lazy_init_ran and not torch.compiler.is_compiling(): + self._lazy_init() + + if self._delay_all_reduce_all_params: + return inputs, kwargs + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + assert self.logger is not None + self.logger.set_runtime_stats_and_log() + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle( + work, + self._divide_by_initial_world_size, # type: ignore[arg-type] + ) + + # Calling _rebuild_buckets before forward computation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logger.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + if self.device_ids: + moved_inputs, moved_kwargs = _to_kwargs( + inputs, + kwargs, + torch.device(self.device_type, self.device_ids[0]), + self.use_side_stream_for_tensor_copies, + ) + args, kwargs = moved_inputs[0], moved_kwargs[0] + # Cast inputs to reduced precision if needed. + if self.mixed_precision is not None: + args, kwargs = _cast_forward_inputs( + self.mixed_precision.param_dtype, + *args, + **kwargs, + ) + return args, kwargs + else: + # Cast inputs to reduced precision if needed. + # TODO (rohan-varma) test this codepath. + if self.mixed_precision is not None: + inputs, kwargs = _cast_forward_inputs( + self.mixed_precision.param_dtype, + *inputs, + **kwargs, + ) + return inputs, kwargs + + def _post_forward(self, output): + if self._use_python_reducer: + return output + + if self._delay_all_reduce_all_params: + self._clear_grad_buffer() + return output + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and not self._static_graph_delay_allreduce_enqueued + ): + ( + output_tensor_list, + treespec, + output_is_rref, + ) = _tree_flatten_with_rref(output) + output_placeholders: list[torch.Tensor | None] = [ + None for _ in range(len(output_tensor_list)) + ] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + weakref.ref(self), + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref( + output_placeholders, treespec, output_is_rref + ) + + # At the end of the forward pass, reset the grad buffer and grad views + self._clear_grad_buffer() + return output + + def forward(self, *inputs, **kwargs): + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + inputs, kwargs = self._pre_forward(*inputs, **kwargs) + output = ( + self.module.forward(*inputs, **kwargs) + if self._delay_all_reduce_all_params + else self._run_ddp_forward(*inputs, **kwargs) + ) + return self._post_forward(output) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def to_kwargs(self, inputs, kwargs, device_id): + # Kept for BC + return _to_kwargs( + inputs, + kwargs, + torch.device(self.device_type, device_id), + self.use_side_stream_for_tensor_copies, + ) + + def gather(self, outputs, output_device): + return gather(outputs, output_device, dim=self.dim) + + def train(self, mode=True): + super().train(mode) + return self + + # When running in join mode, schedules an allreduce to notify joined ranks + # of whether backwards pass synchronization will run this iteration or not. + def _check_global_requires_backward_grad_sync(self, is_joined_rank): + if not is_joined_rank and self.require_backward_grad_sync: + requires_sync_tensor = torch.ones(1, device=self.device) + else: + requires_sync_tensor = torch.zeros(1, device=self.device) + + work = dist.all_reduce( + requires_sync_tensor, group=self.process_group, async_op=True + ) + + # (kwen2501) This if condition is a plain translation of previous + # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()` + # is not called and it doesn't care about the result. I am guessing + # that it just wants to fire a matching all-reduce and does not want + # the main stream to wait. + if is_joined_rank: + work.wait() + should_sync_backwards = requires_sync_tensor.item() != 0 + return should_sync_backwards + else: + return None # Return value is not/should not be used. + + # When running in join mode, checks and performs sync of module buffers if + # the models have buffers that should be synchronized in the forward pass. + def _check_and_sync_module_buffers(self): + if self._check_sync_bufs_pre_fwd(): + authoritative_rank = self._find_common_rank(self._distributed_rank, False) + self._sync_module_buffers(authoritative_rank) + + # When running in join model, agrees upon a common rank and broadcast model + # parameters to all other ranks. + def _sync_final_model(self, is_last_joiner): + # Agree upon the process that will be the authoritative model copy. + # The current rank is a candidate for being the authoritative copy if + # is_last_joiner=True. We break ties via picking the larger rank. + self._authoritative_rank = self._find_common_rank( + self._distributed_rank, is_last_joiner + ) + _sync_module_states( + module=self.module, + process_group=self.process_group, + broadcast_bucket_size=self.broadcast_bucket_size, + src=self._authoritative_rank, + params_and_buffers_to_ignore=self.parameters_to_ignore, + broadcast_buffers=self.broadcast_buffers, + ) + + # Schedule comm ops to match those scheduled in the reducer's backward + # pass. + def _match_all_reduce_for_bwd_pass(self): + comm_work = [] + # Schedule comm in the same order as Reducer schedules them, i.e. + # the order of the buckets. Retrieving the bucket order from the reducer + # ensures that we keep the same order in join mode, such as when bucket + # order is rebuilt dynamically. + + # Returns grad_buckets in order, but real tensors are substituted with + # zero tensors of the same shape. + grad_buckets = self.reducer._get_zeros_like_grad_buckets() + for grad_bucket in grad_buckets: + # Joined processes contribute zero gradient. In the case that + # divide_by_initial_world_size=True, we divide grads by the static + # world size, if not, the dividing factor is reduced by the number + # of joined processes. + work = self.reducer._run_comm_hook(grad_bucket) + comm_work.append(work) + for work in comm_work: + work.wait() + + # Allreduces the used parameter mapping across ranks. + def _match_unused_params_allreduce(self): + locally_used_param_map = self.reducer._get_local_used_map() + self.process_group.allreduce(locally_used_param_map) + + def join( + self, + divide_by_initial_world_size: bool = True, + enable: bool = True, + throw_on_early_termination: bool = False, + ): + r""" + Context manager for training with uneven inputs across processes in DDP. + + This context manager will keep track of already-joined DDP processes, + and "shadow" the forward and backward passes by inserting collective + communication operations to match with the ones created by non-joined + DDP processes. This will ensure each collective call has a corresponding + call by already-joined DDP processes, preventing hangs or errors that + would otherwise happen when training with uneven inputs across + processes. Alternatively, if the flag ``throw_on_early_termination`` is + specified to be ``True``, all trainers will throw an error once one rank + runs out of inputs, allowing these errors to be caught and handled + according to application logic. + + Once all DDP processes have joined, the context manager will broadcast + the model corresponding to the last joined process to all processes to + ensure the model is the same across all processes + (which is guaranteed by DDP). + + To use this to enable training with uneven inputs across processes, + simply wrap this context manager around your training loop. No further + modifications to the model or data loading is required. + + .. warning:: + If the model or training loop this context manager is wrapped around + has additional distributed collective operations, such as + ``SyncBatchNorm`` in the model's forward pass, then the flag + ``throw_on_early_termination`` must be enabled. This is because this + context manager is not aware of non-DDP collective communication. + This flag will cause all ranks to throw when any one rank + exhausts inputs, allowing these errors to be caught and recovered + from across all ranks. + + Args: + divide_by_initial_world_size (bool): If ``True``, will divide + gradients by the initial ``world_size`` DDP training was launched + with. If ``False``, will compute the effective world size + (number of ranks that have not depleted their inputs yet) and + divide gradients by that during allreduce. Set + ``divide_by_initial_world_size=True`` to ensure every input + sample including the uneven inputs have equal weight in terms of + how much they contribute to the global gradient. This is + achieved by always dividing the gradient by the initial + ``world_size`` even when we encounter uneven inputs. If you set + this to ``False``, we divide the gradient by the remaining + number of nodes. This ensures parity with training on a smaller + ``world_size`` although it also means the uneven inputs would + contribute more towards the global gradient. Typically, you + would want to set this to ``True`` for cases where the last few + inputs of your training job are uneven. In extreme cases, where + there is a large discrepancy in the number of inputs, setting + this to ``False`` might provide better results. + enable (bool): Whether to enable uneven input detection or not. Pass + in ``enable=False`` to disable in cases where you know that + inputs are even across participating processes. Default is + ``True``. + throw_on_early_termination (bool): Whether to throw an error + or continue training when at least one rank has exhausted + inputs. If ``True``, will throw upon the first rank reaching end + of data. If ``False``, will continue training with a smaller + effective world size until all ranks are joined. Note that if + this flag is specified, then the flag + ``divide_by_initial_world_size`` would be ignored. Default + is ``False``. + + + Example:: + + >>> # xdoctest: +SKIP("Distributed") + >>> import torch + >>> import torch.distributed as dist + >>> import os + >>> import torch.multiprocessing as mp + >>> import torch.nn as nn + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> torch.cuda.set_device(rank) + >>> model = nn.Linear(1, 1, bias=False).to(rank) + >>> model = torch.nn.parallel.DistributedDataParallel( + >>> model, device_ids=[rank], output_device=rank + >>> ) + >>> # Rank 1 gets one more input than rank 0. + >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] + >>> with model.join(): + >>> for _ in range(5): + >>> for inp in inputs: + >>> loss = model(inp).sum() + >>> loss.backward() + >>> # Without the join() API, the below synchronization will hang + >>> # blocking for rank 1's allreduce to complete. + >>> torch.cuda.synchronize(device=rank) + """ + return Join( + [self], + enable, + throw_on_early_termination, + divide_by_initial_world_size=divide_by_initial_world_size, + ) + + def join_hook( + self, + **kwargs, + ): + r""" + DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + The hook supports the following keyword arguments: + divide_by_initial_world_size (bool, optional): + If ``True``, then gradients are divided by the initial world + size that DDP was launched with. + If ``False``, then gradients are divided by the effective world + size (i.e. the number of non-joined processes), meaning that + the uneven inputs contribute more toward the global gradient. + Typically, this should be set to ``True`` if the degree of + unevenness is small but can be set to ``False`` in extreme + cases for possibly better results. + Default is ``True``. + """ + divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True) + return _DDPJoinHook( + self, divide_by_initial_world_size=divide_by_initial_world_size + ) + + @property + def join_device(self): + return self.device + + @property + def join_process_group(self): + return self.process_group + + def _register_buffer_comm_hook( + self, + state, + hook: Callable, + comm_hook_location=_BufferCommHookLocation.POST_FORWARD, + ): + r""" + Allow custom registration of hooks that define how buffer are synchronized across ranks. + + The hook takes in an optional state and is passed in a Dict[str, Tensor] + corresponding to buffer names and the buffers, and can run arbitrary reductions + on buffers as opposed to DDP's default broadcast from rank 0. This is useful for + example if a counter needs to be summed or averaged across ranks every iteration. + + Args: + state (Any): Optional state that is passed to the hook. + hook (Callable): Callable with the following signature: + ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]`` + comm_hook_location (_BufferCommHookLocation): Enum value indicating + where to run the hook. + _BufferCommHookLocation.PRE_FORWARD means that the + hook will run _before_ the forward pass, and + _BufferCommHookLocation.POST_FORWARD means that the + hook will run _after_ the forward pass. + + NOTE: To maximize performance, users can return a + List[torch.futures.Future] from their hook, and DDP will + install and await these hooks appropriately at the end of + the backward pass. This will ensure all buffers are + synchronized by the end of the backward pass. If this + setting is used, it is recommended to pass + comm_hook_location=_BufferCommHookLocation.POST_FORWARD, + which will trigger the hook after the forward pass. + If _BufferCommHookLocation.PRE_FORWARD is used, users must + ensure appropriate synchronization when manipulating GPU + buffers in the forward pass. + """ + assert callable(hook) + self.buffer_hook = _BufferCommHook( + buffer_comm_hook=hook, + buffer_comm_hook_state=state, + buffer_comm_hook_location=comm_hook_location, + ) + + def register_comm_hook(self, state: object, hook: Callable): + r""" + Register communication hook for user-defined DDP aggregation of gradients across multiple workers. + + This hook would be very useful for researchers to try out new ideas. For + example, this hook can be used to implement several algorithms like GossipGrad + and gradient compression which involve different communication strategies for + parameter syncs while running Distributed DataParallel training. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in GossipGrad, etc. + + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable with the following signature: + ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: + + This function is called once the bucket is ready. The + hook can perform whatever processing is needed and return + a Future indicating completion of any async work (ex: allreduce). + If the hook doesn't perform any communication, it still + must return a completed Future. The Future should hold the + new value of grad bucket's tensors. Once a bucket is ready, + c10d reducer would call this hook and use the tensors returned + by the Future and copy grads to individual parameters. + Note that the future's return type must be a single tensor. + + We also provide an API called ``get_future`` to retrieve a + Future associated with the completion of ``c10d.ProcessGroup.Work``. + ``get_future`` is currently supported for NCCL and also supported for most + operations on GLOO and MPI, except for peer to peer operations (send/recv). + + .. warning :: + Grad bucket's tensors will not be predivided by world_size. User is responsible + to divide by the world_size in case of operations like allreduce. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + .. warning :: + The Future object that hook returns should contain a single tensor + that has the same shape with the tensors inside grad bucket. + + .. warning :: + ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support + for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``. + + Example:: + Below is an example of a noop hook that returns the same tensor. + + >>> # xdoctest: +SKIP('undefined name') + >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + >>> fut = torch.futures.Future() + >>> fut.set_result(bucket.buffer()) + >>> return fut + >>> ddp.register_comm_hook(state=None, hook=noop) + + Example:: + Below is an example of a Parallel SGD algorithm where gradients are encoded before + allreduce, and then decoded after allreduce. + + >>> # xdoctest: +SKIP('undefined name') + >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + >>> encoded_tensor = encode(bucket.buffer()) # encode gradients + >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() + >>> # Define the then callback to decode. + >>> def decode(fut): + >>> decoded_tensor = decode(fut.value()[0]) # decode gradients + >>> return decoded_tensor + >>> return fut.then(decode) + >>> ddp.register_comm_hook(state=None, hook=encode_and_decode) + """ + self._check_comm_hook(hook) + assert self.logger is not None + self.logger._set_comm_hook_name(hook.__qualname__) + self._comm_hooks.append((hook, state)) + dist._register_comm_hook(self.reducer, state, hook) + + def _register_builtin_comm_hook(self, comm_hook_type): + r""" + Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers. + + The built-in hooks aim to provide efficient C++ implementations for certain hooks, + which might not be as efficient if implemented in Python using a Python communication hook. + + Args: + comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + Example:: + Below is an example of a FP16 compression where gradients are + compressed into 16-bit floating-point numbers before allreduce, and + then decompressed after allreduce. + + >>> # xdoctest: +SKIP('undefined name') + >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) + + """ + assert self.logger is not None + self.logger._set_comm_hook_name(str(comm_hook_type)) + dist._register_builtin_comm_hook(self.reducer, comm_hook_type) + + def _register_fused_optim(self, optim: type, *args, optim_params=None, **kwargs): + r""" + Register an optimizer in DDP to optimize parameter immediately after its gradient reduction. + + Registers an optimizer with DDP such that the optimization for a + parameter will run immediately when that parameter's gradient is + finished with reduction, instead of waiting for all parameters' + gradients to finish reduction. This can result in a training speedup + depending on your workload since the optimizer can run while gradient + reduction for other parameters are still ongoing. In addition, this has + the potential to reduce peak memory consumption during training, as it + only needs to load the per-parameter optimizer states of a single + parameter at a time, instead of loading all per-parameter optimizer + states at once. + + Args: + optim (Type): a ``torch.optim.Optimizer`` class to be registered + as a fused optimizer. + *args (Sequence[Any]): Arguments to forward to `optim`. + optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters + to optimize, similar to `params` argument of traditional `torch.optim` + Optimizers. If this is omitted, all DDP model parameters will be + optimized. + **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`. + + .. warning :: + _register_fused_optim should only be called once on a DDP instance, + and registering multiple fused optimizers for the same DDP model + is not currently supported. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + .. warning :: + _register_fused_optim and register_comm_hook currently do not + compose together, meaning that custom DDP communication hooks are + not supported with overlapped optimizers. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + .. warning :: + Gradient accumulation and DDP `no_sync` are currently not supported + with overlapped optimizer. Please ping + https://github.com/pytorch/pytorch/issues/71595 if this is necessary + for your use case. + + Example:: + + >>> # xdoctest: +SKIP("No rendezvous handler") + >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') + >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) + >>> lr = 1e-2 + >>> betas = (0.9, 0.99) + >>> eps = 1e-6 + >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps) + >>> # Example with subset of parameters + >>> params_to_opt = [list(net.parameters())[0]] + >>> net._register_fused_optim( + ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps + ... ) + """ + # Note: importing in function, otherwise this will cause a circular + # import as optimizer_overlap module needs to import DistributedDataParallel. + from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim + + overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs) + try: + overlapped_optim.register_ddp(self) + except NotImplementedError as e: + raise RuntimeError( + f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}." + ) from e + + def _distributed_broadcast_coalesced( + self, tensors, buffer_size, authoritative_rank=0 + ): + dist._broadcast_coalesced( + self.process_group, tensors, buffer_size, authoritative_rank + ) + + def _check_sync_bufs_post_fwd(self): + return ( + self.will_sync_module_buffers() + and hasattr(self, "buffer_hook") + and self.buffer_hook.buffer_comm_hook_location + == _BufferCommHookLocation.POST_FORWARD + ) + + def _check_sync_bufs_pre_fwd(self): + return self.will_sync_module_buffers() and ( + not hasattr(self, "buffer_hook") + or self.buffer_hook.buffer_comm_hook_location + == _BufferCommHookLocation.PRE_FORWARD + ) + + def will_sync_module_buffers(self): + return ( + self.require_forward_param_sync + and self.broadcast_buffers + and len(self.modules_buffers) > 0 + ) + + def _find_common_rank(self, input_rank, rank_cond): + # -1 indicates that this rank is not under consideration to be the + # common_rank + rank_to_use = torch.tensor( + [input_rank if rank_cond else -1], + device=self.device, + ) + dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group) + if rank_to_use.item() == -1: + self._log_and_throw( + ValueError, + "BUG! Expected rank_cond to be true for at least one process." + " This indicates a bug in PyTorch, please report an issue.", + ) + return rank_to_use.item() + + def _sync_buffers(self): + with torch.no_grad(): + # module buffer sync + # Synchronize buffers across processes. + # If we are running DDP with the join manager, we have to agree + # upon a rank to sync module buffers from, since rank 0 may + # already have been joined and have stale module buffers. + if self._join_config.enable: + authoritative_rank = self._find_common_rank( + self._distributed_rank, True + ) + else: + # The process with rank 0 is considered the authoritative copy. + authoritative_rank = 0 + # Update self.modules_buffers in case any buffers were + # reassigned. + self._assign_modules_buffers() + self._sync_module_buffers(authoritative_rank) + + def _sync_module_buffers(self, authoritative_rank): + if not hasattr(self, "buffer_hook"): + self._default_broadcast_coalesced(authoritative_rank=authoritative_rank) + else: + hook = self.buffer_hook.buffer_comm_hook + state = self.buffer_hook.buffer_comm_hook_state + futs = hook(state, self.named_module_buffers) + if futs is not None: + self.reducer._install_post_backward_futures(futs) + + def _default_broadcast_coalesced( + self, bufs=None, bucket_size=None, authoritative_rank=0 + ): + """ + Broadcasts buffers from rank 0 to rest of workers. + + If bufs, bucket_size are None, default values self.modules_buffers + and self.broadcast_bucket_size are used instead. + """ + if bufs is None: + bufs = self.modules_buffers + if bucket_size is None: + bucket_size = self.broadcast_bucket_size + + self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank) + + def _passing_sync_batchnorm_handle(self, module): + for layer in module.modules(): + if isinstance(layer, torch.nn.modules.SyncBatchNorm): + if self.device_type == "cpu": + self._log_and_throw( + ValueError, + "SyncBatchNorm layers only work with GPU modules", + ) + + def _check_comm_hook(self, hook): + if not callable(hook): + self._log_and_throw(TypeError, "Communication hook must be callable.") + + sig = inspect.signature(hook) + if ( + sig.parameters["bucket"].annotation != inspect._empty + and sig.parameters["bucket"].annotation != dist.GradBucket + ): + self._log_and_throw( + ValueError, + "Communication hook: bucket annotation should be dist.GradBucket.", + ) + + if ( + sig.return_annotation != inspect._empty + and sig.return_annotation != torch.futures.Future[torch.Tensor] + ): + self._log_and_throw( + ValueError, + "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].", + ) + + if hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]: + cuda_supported = ( + torch.version.cuda is not None + ) or torch.version.hip is not None + nccl_supported = ( + dist.is_available() + and dist.is_nccl_available() + and torch.cuda.nccl.version() >= (2, 10) + ) + xpu_xccl_supported = ( + dist.is_available() + and dist.is_xccl_available() + and torch.xpu.is_available() + ) + + if not ((cuda_supported and nccl_supported) or xpu_xccl_supported): + self._log_and_throw( + TypeError, + "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+ or XPU and XCCL", + ) + + @property + def _distributed_rank(self): + return dist.get_rank(self.process_group) + + @staticmethod + def _get_data_parallel_params(module, named_params=False): + """Return a generator of parameters managed by a given DDP unit.""" + for param in ( + module.parameters() if not named_params else module.named_parameters() + ): + if not hasattr(param, "_ddp_ignored"): + yield param + + @staticmethod + def _set_params_and_buffers_to_ignore_for_model( + module, params_and_buffers_to_ignore + ): + """ + Set parameters and buffers to be ignored by DDP. + + Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and + similarly, {module_name}.{buffer_name} for buffers. For example: + params_to_ignore = [] + # NB: model here is vanilla PyTorch module, not yet wrapped with DDP. + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if should_ignore(param): + # Create expected format + fqn = f"{module_name}.{param_name}" + params_to_ignore.append(fqn) + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, + params_to_ignore + ) + """ + # This is a workaround to set parameters and buffers DDP should ignore + # during synchronization. It will be removed when the API is finalized + # as part of addressing https://github.com/pytorch/pytorch/issues/43690. + module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore + for name, param in module.named_parameters(): + if name in params_and_buffers_to_ignore: + param._ddp_ignored = True + for name, buffer in module.named_buffers(): + if name in params_and_buffers_to_ignore: + buffer._ddp_ignored = True + + def _get_ddp_logging_data(self): + r""" + Return a dictionary of logging data for debugging and analysis. + + This interface can be called after DistributedDataParallel() is + constructed. It returns a dictionary of logging data. It could help + for debugging and analysis. The logging data includes DistributedDataParallel + constructor input parameters, some internal states of DistributedDataParallel + and performance metrics. Simply print the dictionary and see what + these metrics are. + This is a prototype interface and subject to change in the future. + """ + assert self.logger is not None + ddp_logging_data = self.logger._get_ddp_logging_data() + return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map} + + def _set_ddp_runtime_logging_sample_rate(self, sample_rate): + r""" + Set sample_rate of collecting runtime stats. + + This interface allows users to set sample_rate of collecting + runtime stats. The runtime stats will be recorded for the + first 10 iterations, after 10 iterations runtime stats will be + recorded once every "sample_rate" training iterations. In + default, runtime stats are recorded for the first 10 iterations, + after 10 iterations runtime stats are recorded once every + "kDDPRuntimeLoggingSampleRate=100" training iterations. + This is a prototype interface and subject to change in the future. + """ + if sample_rate < 1: + self._log_and_throw( + ValueError, + "DDP runtime logging sample rate should be equal or greater than 1", + ) + self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate) + + def _set_static_graph(self): + """ + Set static graph for DDP. + + It is recommended to set static graph in the DDP constructor, which will + call this private API internally. + """ + # If self.static_graph has been set, no need to set it again + if self.static_graph: + warnings.warn( + "You've set static_graph to be True, no need to set it again.", + stacklevel=2, + ) + return + self.static_graph = True + self._static_graph_delay_allreduce_enqueued = False + self.reducer._set_static_graph() + assert self.logger is not None + self.logger._set_static_graph() + if self.find_unused_parameters: + warnings.warn( + "You passed find_unused_parameters=true to DistributedDataParallel, " + "`_set_static_graph` will detect unused parameters automatically, so " + "you do not need to set find_unused_parameters=true, just be sure these " + "unused parameters will not change during training loop while calling " + "`_set_static_graph`.", + stacklevel=2, + ) + + def _remove_autograd_hooks(self): + """Remove autograd hooks registered by the reducer on the model parameters.""" + self.reducer._remove_autograd_hooks() + + def _check_reducer_finalized(self): + """ + Check if the reducer has processed all buckets and finalized the backward appropriately. + + It is useful to call this method after calling .backward() in your training loop + in order to avoid subsequent hard to debug errors down the road due to the + reducer not finalizing backward. + """ + self.reducer._check_reducer_finalized() + + def _set_sparse_metadata(self, global_unique_ids): + self.reducer._set_sparse_metadata(global_unique_ids) + + def _update_process_group(self, new_process_group): + """ + Dynamically updates the process group for DDP so that we can shrink/expand DDP + world size without having to reinitialize DDP. + + NOTE: If you are using custom communications hooks via, register_comm_hook, + you need to update the process groups for those hooks separately. + """ + # Force a rebuild of buckets for a new process group. This ensures all ranks + # are synchronized in terms of when they will rebuild buckets and also + # re-evaluates previous assumptions of buckets given the world size might have + # changed. + self._has_rebuilt_buckets = False + self.reducer._reset_state() + + if not _rank_not_in_group(new_process_group): + self.process_group = new_process_group + self.reducer._update_process_group(new_process_group) + + def _set_ddp_sink_clone(self, val: bool): + """ + Sets whether or not DDPSink should clone the output tensors or not. + The default is True since if the loss is modified in place we run + into the view is modified in-place error. + + Although, cloning the tensors can add significant memory and + performance hit if the number and size of tensors are large. As + a result, this can be set to False if you are not modifying the + loss in place. + """ + self._ddp_sink_clone = val diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..6c26aaf5048e908ab72978b9d8562d4997c17928 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py @@ -0,0 +1,135 @@ +import threading +from collections.abc import Sequence +from typing import Any, cast + +import torch +from torch._utils import ExceptionWrapper +from torch.cuda._utils import _get_device_index +from torch.nn.modules import Module + + +__all__ = ["get_a_var", "parallel_apply"] + + +def get_a_var( + obj: torch.Tensor | list[Any] | tuple[Any, ...] | dict[Any, Any], +) -> torch.Tensor | None: + if isinstance(obj, torch.Tensor): + return obj + + if isinstance(obj, (list, tuple)): + for result in map(get_a_var, obj): + if isinstance(result, torch.Tensor): + return result + if isinstance(obj, dict): + for result in map(get_a_var, obj.items()): + if isinstance(result, torch.Tensor): + return result + return None + + +def parallel_apply( + modules: Sequence[Module], + inputs: Sequence[Any], + kwargs_tup: Sequence[dict[str, Any]] | None = None, + devices: Sequence[int | torch.device | None] | None = None, +) -> list[Any]: + r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`. + + Args: + modules (Module): modules to be parallelized + inputs (tensor): inputs to the modules + devices (list of int or torch.device): CUDA devices + + :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and + :attr:`devices` (if given) should all have same length. Moreover, each + element of :attr:`inputs` can either be a single object as the only argument + to a module, or a collection of positional arguments. + """ + assert len(modules) == len(inputs), ( + f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}" + ) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = (cast(dict[str, Any], {}),) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + devices = [_get_device_index(x, True) for x in devices] + streams = [torch.accelerator.current_stream(x) for x in devices] + assert torch.accelerator.is_available(), "No available accelerator found." + device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr] + lock = threading.Lock() + results = {} + grad_enabled, autocast_enabled = ( + torch.is_grad_enabled(), + torch.is_autocast_enabled(), + ) + + def _worker( + i: int, + module: Module, + input: Any, + kwargs: dict[str, Any], + device: int | torch.device | None = None, + stream: torch.Stream | None = None, + ) -> None: + torch.set_grad_enabled(grad_enabled) + if device is None: + t = get_a_var(input) + if t is None: + with lock: + results[i] = ExceptionWrapper( + where=f"in replica {i}, no device was provided and no tensor input was found; " + "device cannot be resolved" + ) + return + device = t.get_device() + if isinstance(device, torch.device): + device = device.index + if stream is None: + stream = torch.accelerator.current_stream(device) + try: + with ( + torch.accelerator.device_index(device), + stream, + torch.amp.autocast(device_type, enabled=autocast_enabled), + ): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + output = module(*input, **kwargs) + with lock: + results[i] = output + except Exception: + with lock: + results[i] = ExceptionWrapper( + where=f"in replica {i} on device {device}" + ) + + if len(modules) > 1: + threads = [ + threading.Thread( + target=_worker, args=(i, module, input, kwargs, device, stream) + ) + for i, (module, input, kwargs, device, stream) in enumerate( + zip(modules, inputs, kwargs_tup, devices, streams, strict=True) + ) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, ExceptionWrapper): + output.reraise() + outputs.append(output) + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/replicate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7844ab4aba222055f726492df33d2a61aba880 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/replicate.py @@ -0,0 +1,203 @@ +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from typing import cast, TYPE_CHECKING, TypeVar +from typing_extensions import TypeIs + +import torch +from torch._utils import _get_device_index +from torch.nn.modules import Module +from torch.nn.parallel import comm + + +if TYPE_CHECKING: + from torch._C import ScriptMethod + from torch.jit import ScriptModule + from torch.jit._state import EnabledProxy + + +__all__ = ["replicate"] + + +def _is_script_module(module: Module) -> TypeIs["ScriptModule"]: + import torch.jit + + return isinstance(module, torch.jit.ScriptModule) + + +def _is_script_method(module: object) -> TypeIs["ScriptMethod"]: + import torch.jit + + return isinstance(module, torch._C.ScriptMethod) + + +def _init_script_module() -> "ScriptModule": + import torch.jit + + return torch.jit.ScriptModule() + + +def _is_jit_enabled() -> "EnabledProxy": + import torch.jit._state + + return torch.jit._state._enabled + + +# Check if we can safely replicate the module. +# there are two types of module: +# 1. python modules +# 2. ScriptModule +# +# currently a module cannot be replicated properly if the descendants of +# any ScriptModule contains python module (type 1 above) +def _replicatable_module(module: Module, memo: set[Module] | None = None) -> bool: + # module.modules() contains module itself as the first element + def descendant_modules(module: Module) -> Iterator[Module]: + gen = module.modules() + next(gen) + return gen + + if not _is_jit_enabled(): + return True + if memo is None: + memo = set() + + # memoize visited modules + memo.add(module) + if _is_script_module(module): + memo.update(descendant_modules(module)) + return all( + _is_script_module(descendant) for descendant in descendant_modules(module) + ) + + for child in module.children(): + # since any unreplicatable module will cause the check to return + # False early, visited modules here can be safely ignored. + if child in memo: + continue + if not _replicatable_module(child, memo): + return False + + return True + + +def _broadcast_coalesced_reshape( + tensors: Sequence[torch.Tensor], + devices: Sequence[int | torch.device], + detach: bool = False, +) -> list[list[torch.Tensor]]: + from torch.nn.parallel._functions import Broadcast + + if detach: + return comm.broadcast_coalesced(tensors, devices) + else: + # Use the autograd function to broadcast if not detach + if len(tensors) > 0: + tensor_copies = Broadcast.apply(devices, *tensors) + return [ + tensor_copies[i : i + len(tensors)] + for i in range(0, len(tensor_copies), len(tensors)) + ] + else: + return [] + + +T = TypeVar("T", bound=Module) + + +def replicate( + network: T, + devices: Sequence[int | torch.device], + detach: bool = False, +) -> list[T]: + if not _replicatable_module(network): + raise RuntimeError( + "Cannot replicate network where python modules are children of ScriptModule" + ) + + if not devices: + return [] + + devices = [_get_device_index(x, True) for x in devices] + num_replicas = len(devices) + + params = list(network.parameters()) + param_indices = {param: idx for idx, param in enumerate(params)} + param_copies = _broadcast_coalesced_reshape(params, devices, detach) + + buffers = list(network.buffers()) + buffers_rg: list[torch.Tensor] = [] + buffers_not_rg: list[torch.Tensor] = [] + for buf in buffers: + if buf.requires_grad and not detach: + buffers_rg.append(buf) + else: + buffers_not_rg.append(buf) + + buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} + buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} + + buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) + buffer_copies_not_rg = _broadcast_coalesced_reshape( + buffers_not_rg, devices, detach=True + ) + + modules = list(network.modules()) + module_copies: list[list[Module]] = [[] for _ in devices] + module_indices: dict[Module, int] = {} + + for i, module in enumerate(modules): + module_indices[module] = i + for j in range(num_replicas): + replica = module._replicate_for_data_parallel() + # This is a temporary fix for DDP. DDP needs to access the + # replicated model parameters. It used to do so through + # `mode.parameters()`. The fix added in #33907 for DP stops the + # `parameters()` API from exposing the replicated parameters. + # Hence, we add a `_former_parameters` dict here to support DDP. + replica._former_parameters = OrderedDict() + + module_copies[j].append(replica) + + for i, module in enumerate(modules): + for key, child in module._modules.items(): + if child is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._modules[key] = None + else: + module_idx = module_indices[child] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, module_copies[j][module_idx]) + for key, param in module._parameters.items(): + if param is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._parameters[key] = None + else: + param_idx = param_indices[param] + for j in range(num_replicas): + replica = module_copies[j][i] + param_copy = param_copies[j][param_idx] + # parameters in replicas are no longer leaves, + # so setattr them as non-parameter attributes + setattr(replica, key, param_copy) + # expose the parameter for DDP + replica._former_parameters[key] = param_copy # type: ignore[operator, index] + for key, buf in module._buffers.items(): # type: ignore[assignment] + if buf is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._buffers[key] = None + else: + if buf.requires_grad and not detach: + buffer_copies = buffer_copies_rg + buffer_idx = buffer_indices_rg[buf] + else: + buffer_copies = buffer_copies_not_rg + buffer_idx = buffer_indices_not_rg[buf] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, buffer_copies[j][buffer_idx]) + + return [cast(T, module_copies[j][0]) for j in range(num_replicas)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/scatter_gather.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/scatter_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..27aeaf19944dcadab63b25d0c9789c31dff322da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/parallel/scatter_gather.py @@ -0,0 +1,154 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, overload, TypeVar +from typing_extensions import deprecated + +import torch +from torch.nn.parallel._functions import Gather, Scatter + + +__all__ = ["scatter", "scatter_kwargs", "gather"] + + +@deprecated( + "`is_namedtuple` is deprecated, please use the python checks instead", + category=FutureWarning, +) +def is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return _is_namedtuple(obj) + + +def _is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + +T = TypeVar("T", dict, list, tuple) + + +# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise. +@overload +def scatter( + inputs: torch.Tensor, + target_gpus: Sequence[int | torch.device], + dim: int = ..., +) -> tuple[torch.Tensor, ...]: ... + + +@overload +def scatter( + inputs: T, + target_gpus: Sequence[int | torch.device], + dim: int = ..., +) -> list[T]: ... + + +def scatter(inputs, target_gpus, dim=0): + r"""Slice tensors into approximately equal chunks and distributes them across given GPUs. + + Duplicates references to objects that are not tensors. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + return Scatter.apply(target_gpus, None, dim, obj) + if _is_namedtuple(obj): + # pyrefly: ignore [no-matching-overload] + return [ + # pyrefly: ignore [no-matching-overload] + type(obj)(*args) + # pyrefly: ignore # no-matching-overload + for args in zip(*map(scatter_map, obj), strict=False) + ] + if isinstance(obj, tuple) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return list(zip(*map(scatter_map, obj), strict=False)) + if isinstance(obj, list) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return [list(i) for i in zip(*map(scatter_map, obj), strict=False)] + if isinstance(obj, dict) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return [ + # pyrefly: ignore [no-matching-overload] + type(obj)(i) + # pyrefly: ignore # no-matching-overload + for i in zip(*map(scatter_map, obj.items()), strict=False) + ] + return [obj for _ in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + res = scatter_map(inputs) + finally: + scatter_map = None # type: ignore[assignment] + return res + + +def scatter_kwargs( + inputs: tuple[Any, ...], + kwargs: dict[str, Any] | None, + target_gpus: Sequence[int | torch.device], + dim: int = 0, +) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: + r"""Scatter with support for kwargs dictionary.""" + scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else [] + scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(scattered_inputs) < len(scattered_kwargs): + scattered_inputs.extend( + () for _ in range(len(scattered_kwargs) - len(scattered_inputs)) + ) + elif len(scattered_kwargs) < len(inputs): + scattered_kwargs.extend( + {} for _ in range(len(scattered_inputs) - len(scattered_kwargs)) + ) + return tuple(scattered_inputs), tuple(scattered_kwargs) + + +def gather(outputs: Any, target_device: int | torch.device, dim: int = 0) -> Any: + r"""Gather tensors from different GPUs on a specified device. + + This function is useful for gathering the results of a distributed computation. + It takes a sequence of objects, one for each GPU, and returns a single object + on the specified device. + + Args: + outputs (Any): A sequence of objects (potentially tensors) to gather. + target_device (Union[int, torch.device]): The device to gather the tensors to. + Use 'cpu' for CPU to avoid a deprecation warning. + dim (int, optional): The dimension along which to gather. Default: 0. + + Returns: + Any: A gathered object (potentially tensor) on the specified device. + """ + + def gather_map(outputs): + out = outputs[0] + if isinstance(out, torch.Tensor): + return Gather.apply(target_device, dim, *outputs) + if out is None: + return None + if isinstance(out, dict): + if not all(len(out) == len(d) for d in outputs): + raise ValueError("All dicts must have the same number of keys") + # pyrefly: ignore [not-callable] + return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) + if _is_namedtuple(out): + # pyrefly: ignore [no-matching-overload] + return type(out)._make(map(gather_map, zip(*outputs, strict=True))) + # pyrefly: ignore [no-matching-overload] + return type(out)(map(gather_map, zip(*outputs, strict=True))) + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. + try: + res = gather_map(outputs) + finally: + gather_map = None # type: ignore[assignment] + return res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..766b09382aa78e65aba915e4e6faf7979c500d1b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__init__.py @@ -0,0 +1,19 @@ +# flake8: noqa: F401 +r"""QAT Dynamic Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.dynamic` instead. +""" + +from torch.nn.qat import dynamic, modules # noqa: F403 +from torch.nn.qat.modules import * # noqa: F403 + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fbe59c98185053e32c9547b3c01d6ce1a2aaa2d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56838a1cfcae74c171479a717ff4e80ae7c53d71 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa: F401 +r"""QAT Dynamic Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.dynamic` instead. +""" + +from torch.nn.qat.dynamic.modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3157ebf745a4be74c89c430f45e5929dc37c4c07 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3790f21256074a9a158d3ad15fa2ee0044fd0e6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from torch.nn.qat.dynamic.modules.linear import Linear + + +__all__ = ["Linear"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..906f291438c3238d5d86ef70befdf28727d215d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a45ad97e4934b94fc4987be6caeaec624b8bf579 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5c80ea213c6287d2134e5344ddb19e538d0d5c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.dynamic.modules.linear import Linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f55fbdf789a7c845896f5b3bb0570ef74b5890 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.modules` instead. +""" + +from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d +from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag +from torch.ao.nn.qat.modules.linear import Linear +from torch.nn.qat.modules import conv, embedding_ops, linear + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef275ec64972f249df4bb7f663b7be6c01b0c62b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ce1cb7a2d31e87ccde0b7acebf239a0932f630 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..310ee369571c6a809b8c76b4e6a8048b52c3d176 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7616abc6cb4e92d085015e461fb49f259f1ec63a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..31aa101d8f60a5a5a1c61dae98d55db6f187a1c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/conv.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/embedding_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0964739f9e6ac96db8d3c8a7d156d371cc0725 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/embedding_ops.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag + + +__all__ = ["Embedding", "EmbeddingBag"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..4e822eba7e0617f7d950f6b269398e729524a28f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/qat/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.qat.modules.linear import Linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7628c5c15992efa600ea5520aed955ba42c6146 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from torch.nn.quantizable.modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cb4e6a986e302c2a977fdae9fb0b9687d9763f6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e0ca1743566174b7abfd3663dfa90b744ba56f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention +from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + "MultiheadAttention", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d90e1d60dcbbfc81849494855206e512f4c3c98 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..130574f30720e63cb0813de867aa1439654b4aec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..202b912b0fe32d069058858e259e41dfadcfc29d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..28f3eee958115d05d161af8acf1b2308c02c3248 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/activation.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..8e355efcdd63d595580bdac22d46882b91b7d118 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantizable/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2bbbc13202db1cbddaad4b05241a62190adc46 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__init__.py @@ -0,0 +1,39 @@ +from torch.nn.quantized import dynamic, functional, modules # noqa: F403 +from torch.nn.quantized.modules import * # noqa: F403 +from torch.nn.quantized.modules import MaxPool2d + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "Dropout", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "PReLU", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..460317478b4547963c8ad9557e8693f3b1a11471 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/functional.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6956a441283772c5c841ad665846252011fe8f9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/__pycache__/functional.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61faa90bd95cc7e255be2df82c617b5bab46b044 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py @@ -0,0 +1 @@ +from torch.nn.quantized._reference.modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a01c594613e17db08fe3751d71beff6a9dfd7598 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9caa8e58f19393bf3c19421dbaa528b1f86b996 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.reference.modules.linear import Linear +from torch.ao.nn.quantized.reference.modules.rnn import GRUCell, LSTM, LSTMCell, RNNCell +from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2fe6a4f6cb0f5e8f991b42c3cf0e0dc4ee93d53 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7be9715fd7a9b8ddbbdd3a26ad639194d69be3ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f87abf6c55c1e4aa03a9cff54b212643bd0e0d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e439d6899bc62f8950d379e3eca35f31eb14bc08 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..988f329881c391ad73e721efe30b88e8a54cd06c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa34465cc655a466692b223da60302876213967 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ac806d0fa60d657f32603e7d761687f3c4a64215 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/conv.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.conv import ( + _ConvNd, + _ConvTransposeNd, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..6be6d5a140bb58f76b0e6061eb4ccb37d385757f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/linear.py @@ -0,0 +1,12 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.linear import Linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..0573b3309b64b92ef3bb59200117b4ef5a62680b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/rnn.py @@ -0,0 +1,19 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.rnn import ( + GRUCell, + LSTM, + LSTMCell, + RNNBase, + RNNCell, + RNNCellBase, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/sparse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..85bf997d478ae1d0b4631541433e01f8f2943633 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/sparse.py @@ -0,0 +1,12 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.sparse import Embedding, EmbeddingBag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a0f30f2fb92c2717cb9f794d69e1e905bff6ed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/_reference/modules/utils.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.utils import ( + _get_weight_qparam_keys, + _quantize_and_dequantize_weight, + _quantize_weight, + _save_weight_qparams, + ReferenceQuantizedModule, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b08cd1bc7149c5506db3a952fff488eb06749f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from torch.ao.nn.quantized.dynamic import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c881c9fae32916e9519816908678e1f23cfc88bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae09e82c3bb85f754d554af6eb7ac36f29e56ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,43 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn +from torch.ao.nn.quantized.dynamic.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.dynamic.modules.linear import Linear +from torch.ao.nn.quantized.dynamic.modules.rnn import ( + GRU, + GRUCell, + LSTM, + LSTMCell, + RNNCell, +) + + +__all__ = [ + "Linear", + "LSTM", + "GRU", + "LSTMCell", + "RNNCell", + "GRUCell", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d84df880075858ed31be8b76db4a655496b9b5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac5ab7c137aaecd1639a07d737eaa47f329ae36b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5905c0730620039a851b8ce1901152f621c162c2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba29168f6b319fa00d22d50a44411738b5498a25 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b81a68a88917cddd0fe4b6987bc8ff6eddef01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,28 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..b23fae2c06aa8c829b0aeab7d29a4c84ad5c35b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.linear import Linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ca396a2d44030f61c8fdf64c8fe4aed3e0cf5d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,34 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules.rnn import ( + GRU, + GRUCell, + LSTM, + LSTMCell, + pack_weight_bias, + PackedParameter, + RNNBase, + RNNCell, + RNNCellBase, +) + + +__all__ = [ + "pack_weight_bias", + "PackedParameter", + "RNNBase", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/functional.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..d763e171fdb432c8ba2059cc2332e7ac6424854a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/functional.py @@ -0,0 +1,10 @@ +r"""nn.quantized.functional. + +Quantized equivalents of the `nn.functional`. + +Note:: + This location is in the process of being deprecated. + Please, use the `torch.ao.nn.quantized.functional` instead. +""" + +from torch.ao.nn.quantized.functional import * # noqa: F401,F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae76d1968b0faaf30f861ab009b9011ce2960cc5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py @@ -0,0 +1,97 @@ +r"""Quantized Modules. + +Note:: + The `torch.nn.quantized` namespace is in the process of being deprecated. + Please, use `torch.ao.nn.quantized` instead. +""" + +# The following imports are needed in case the user decides +# to import the files directly, +# s.a. `from torch.nn.quantized.modules.conv import ...`. +# No need to add them to the `__all__`. +from torch.ao.nn.quantized.modules import ( + activation, + batchnorm, + conv, + DeQuantize, + dropout, + embedding_ops, + functional_modules, + linear, + MaxPool2d, + normalization, + Quantize, + rnn, + utils, +) +from torch.ao.nn.quantized.modules.activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) +from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d +from torch.ao.nn.quantized.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from torch.ao.nn.quantized.modules.dropout import Dropout +from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag +from torch.ao.nn.quantized.modules.functional_modules import ( + FloatFunctional, + FXFloatFunctional, + QFunctional, +) +from torch.ao.nn.quantized.modules.linear import Linear +from torch.ao.nn.quantized.modules.normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) +from torch.ao.nn.quantized.modules.rnn import LSTM + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84615b0cd9b9e743dffea36f846653ae63c237df Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8302cae1f961b39b9af2ad46499b1f3bf4bafec5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5308a61d6c148d227903547b8ae47fe2a62e8ec1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1941875f09cb5fdb8dfa2eacd1df008f67f246c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c09b04fbba10c930f43cb0e4724ed99448792fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ebfb799021c4f58152fcac0772d40ba83ba00ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..231600dbe38c7cd5cc854730622da5c5bc0aff4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9d649943a85674078f40c770d4162200612929f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ef3884e4a676a86f53e7e468029e11dbc30a18b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94d0d984b7e655482cc666735308d4c6c16bab6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c5a57d81e5e938f4d1cd7d2da01a14eabc5544a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..d85162ef35c7cd6a399f2a73a9a6b8f3c1154cd9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py @@ -0,0 +1,20 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/batchnorm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8489cdb596ef44ce15d79530c80a9c7ea512e975 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/batchnorm.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9c77b534ff6f6bd28466dda1c16ed219e48c1d73 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/conv.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.conv import ( + _reverse_repeat_padding, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/dropout.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..32a7a22d558670cc4ae9a963240badd314ed6d5c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/dropout.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.dropout import Dropout + + +__all__ = ["Dropout"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d25f8bea7e378023a8eb3ece75a5fb9a23163529 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.embedding_ops import ( + Embedding, + EmbeddingBag, + EmbeddingPackedParams, +) + + +__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/functional_modules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..efe1b38ce3ea4adbae55595d86c2787d7c1f7284 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/functional_modules.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.functional_modules import ( + FloatFunctional, + FXFloatFunctional, + QFunctional, +) + + +__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ba5a5c12f82915db53d81a7b9e5a1c0e530e98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams + + +__all__ = ["LinearPackedParams", "Linear"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/normalization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..85462cc365344b004c91ff9c02879477d50041f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/normalization.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) + + +__all__ = [ + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a0076d13bc4e3ee29e9b3e410171d20e8e9a65 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.rnn import LSTM diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea333af04ca49138a3b3ed35020654d4dad5ffe9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _ntuple_from_first, + _pair_from_first, + _quantize_weight, + WeightedQuantizedModule, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9253264d1e0eaf7fef1ee4ada06d2bf0be5cda7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__init__.py @@ -0,0 +1,48 @@ +from . import parametrizations, parametrize, rnn, stateless +from .clip_grad import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] + _clip_grads_with_norm_ as clip_grads_with_norm_, + _get_total_norm as get_total_norm, + clip_grad_norm, + clip_grad_norm_, + clip_grad_value_, +) +from .convert_parameters import parameters_to_vector, vector_to_parameters +from .fusion import ( + fuse_conv_bn_eval, + fuse_conv_bn_weights, + fuse_linear_bn_eval, + fuse_linear_bn_weights, +) +from .init import skip_init +from .memory_format import ( + convert_conv2d_weight_memory_format, + convert_conv3d_weight_memory_format, +) +from .spectral_norm import remove_spectral_norm, spectral_norm +from .weight_norm import remove_weight_norm, weight_norm + + +__all__ = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grads_with_norm_", + "clip_grad_value_", + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + "get_total_norm", + "parameters_to_vector", + "parametrizations", + "parametrize", + "remove_spectral_norm", + "remove_weight_norm", + "rnn", + "skip_init", + "spectral_norm", + "stateless", + "vector_to_parameters", + "weight_norm", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83e851b3ef19266f3ad275a8739dba891692ccea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b644459da415b2b7bc2b811125983948d2af9a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..319df68269c73c9dd7a9be996550e4143ac1764e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa8a25148f7b59f2d17eeb3b6969ab43f44cb2e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1326ebe839765a7ead67728c3113d7b5880fac72 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..358ce0515876a28173f09036debf26af0cf2a9a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c207a866ec509888b74ae98257edf72445f15cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..154a596889c4d34f865db4d47504271bc51a18ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab49380d143eb64fbb95eb434aff574089d107dd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2989111ec17a68a0505cfe6fe6b0c5461e8b69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ed13433d8c30f2299a6b9be52884467acbc780 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2133e8fc63d0123a9592e642ee792e481ecc9ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23d18545fc280f00d483e524ddf7aff06d783c72 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2e1ae2fe9c30518fb6ac02cb608b7b5228b14d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22e2e7d91fbdc2441b3ba0cee000951c2b2d7ce6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..714cc1840fc15df647bcf0b98b238c91909593c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_deprecation_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a25b647307900e42b11d1cdafc8d9f8785d1a620 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_deprecation_utils.py @@ -0,0 +1,53 @@ +import importlib +import warnings +from collections.abc import Callable + + +_MESSAGE_TEMPLATE = ( + r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead." +) + + +def lazy_deprecated_import( + all: list[str], + old_module: str, + new_module: str, +) -> Callable: + r"""Import utility to lazily import deprecated packages / modules / functional. + + The old_module and new_module are also used in the deprecation warning defined + by the `_MESSAGE_TEMPLATE`. + + Args: + all: The list of the functions that are imported. Generally, the module's + __all__ list of the module. + old_module: Old module location + new_module: New module location / Migrated location + + Returns: + Callable to assign to the `__getattr__` + + Usage: + + # In the `torch/nn/quantized/functional.py` + from torch.nn.utils._deprecation_utils import lazy_deprecated_import + _MIGRATED_TO = "torch.ao.nn.quantized.functional" + __getattr__ = lazy_deprecated_import( + all=__all__, + old_module=__name__, + new_module=_MIGRATED_TO) + """ + warning_message = _MESSAGE_TEMPLATE.format( + old_location=old_module, new_location=new_module + ) + + def getattr_dunder(name: str) -> None: + if name in all: + # We are using the "RuntimeWarning" to make sure it is not + # ignored by default. + warnings.warn(warning_message, RuntimeWarning, stacklevel=2) + package = importlib.import_module(new_module) + return getattr(package, name) + raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.") + + return getattr_dunder diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0eaf86bdbeacfe7d4e7cbd50daf11385955d7d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__init__.py @@ -0,0 +1,10 @@ +from .conv_expanded_weights import ConvPerSampleGrad +from .embedding_expanded_weights import EmbeddingPerSampleGrad +from .expanded_weights_impl import ExpandedWeight +from .group_norm_expanded_weights import GroupNormPerSampleGrad +from .instance_norm_expanded_weights import InstanceNormPerSampleGrad +from .layer_norm_expanded_weights import LayerNormPerSampleGrad +from .linear_expanded_weights import LinearPerSampleGrad + + +__all__ = ["ExpandedWeight"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f89bf4d32dd263673070cb023972280737d03ffa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96604c9444e7d008cee41923ea7068de0f7d1cf2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..519212bf69b5b698fa198db33be3ccf36f31e86a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b19d199e2069b7b2ef68470823432cfc8721ebd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41452620c9b0d258c4e1edcf7a288164a7e02e9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95f1750430df7614436cb12ccd605220ec966210 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43806308d0c29abb593409e48d2708153802d35 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f3bc2be784cd7949ee4118d4aa15d4faa4deb7a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b037fc3039820e309d3e1c2769f2d9f3e8990dd6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f277c27ba1e5b3b7d7a8f58e667fbc7eba62776 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..da7d8f3dfabb8c82f01d76c27ffbe1d3974473cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -0,0 +1,82 @@ +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch.nn.functional as F + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +from .conv_utils import ( + conv_args_and_kwargs, + conv_backward, + conv_input_for_string_padding, + conv_picker, +) +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import forward_helper + + +@implements_per_sample_grads(F.conv1d) +@implements_per_sample_grads(F.conv2d) +@implements_per_sample_grads(F.conv3d) +class ConvPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward( + ctx: Any, + kwarg_names: list[str], + conv_fn: Callable[_P, _R], + *expanded_args_and_kwargs: Any, + ) -> torch.Tensor: + expanded_args, expanded_kwargs = conv_args_and_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + orig_input = expanded_args[0] + was_same_padding = expanded_kwargs["padding"] == "same" + + if isinstance(expanded_kwargs["padding"], str): + # if padding is a string, we'll do the necessary padding (slowly) using F.pad + kernel_size = expanded_args[1].shape[2:] + padding, dilation = expanded_kwargs["padding"], expanded_kwargs["dilation"] + input = conv_input_for_string_padding( + conv_fn, padding, expanded_args[0], dilation, kernel_size + ) + expanded_args = (input, expanded_args[1]) + # since we've already done the padding, don't need any more + expanded_kwargs["padding"] = 0 + + output = forward_helper(conv_fn, expanded_args, expanded_kwargs) + input, weight = expanded_args + batched_dim_size = conv_picker(conv_fn, 3, 4, 5) + if input.dim() != batched_dim_size: + raise RuntimeError( + f"Expanded Weights only support convolution with batched input, got {conv_fn} with an" + f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}" + ) + + # pyrefly: ignore [invalid-type-var] + ctx.conv_fn = conv_fn + + ctx.batch_size = orig_input.shape[0] + ctx.input_required_grad = orig_input.requires_grad + ctx.orig_input_shape = orig_input.shape + ctx.was_same_padding = was_same_padding + ctx.stride, ctx.padding = expanded_kwargs["stride"], expanded_kwargs["padding"] + ctx.dilation, ctx.groups = ( + expanded_kwargs["dilation"], + expanded_kwargs["groups"], + ) + + if isinstance(weight, ExpandedWeight): + ctx.input = input + ctx.weight = weight + ctx.bias = expanded_kwargs["bias"] + + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + return conv_backward(ctx.conv_fn, ctx, grad_outputs[0]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3444ed64d3ca580a321eeebdc4c5a8e442c49ae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn.functional as F + +from .expanded_weights_utils import ( + set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +THRESHOLD = 32 + + +def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): + if func is F.conv1d: + return conv1dOpt + if func is F.conv2d: + return conv2dOpt + else: + assert func is F.conv3d + return conv3dOpt + + +def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): + args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)] + kwargs = expanded_args_and_kwargs[ + len(expanded_args_and_kwargs) - len(kwarg_names) : + ] + kwargs = dict(zip(kwarg_names, kwargs, strict=True)) + + return conv_normalizer(*args, **kwargs) + + +def conv_normalizer( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, +): + return (input, weight), { + "bias": bias, + "stride": stride, + "padding": padding, + "dilation": dilation, + "groups": groups, + } + + +def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size): + if padding_style == "valid": + return input + else: + padding = int_padding_for_string_padding( + func, padding_style, dilation, kernel_size + ) + return F.pad(input, padding) + + +def int_padding_for_string_padding(func, padding_style, dilation, kernel_size): + def get_dilation(i): + return dilation[i] if isinstance(dilation, tuple) else dilation + + if padding_style == "same": + padding: list[int] = [] + # F.pad needs the padding in reverse order from what conv expects + for i in range(conv_picker(func, 0, 1, 2), -1, -1): + padding += conv_padding_for_same(get_dilation(i), kernel_size[i]) + return padding + elif padding_style == "valid": + return conv_picker(func, 2, 4, 6) * (0,) + else: + raise RuntimeError( + f"got padding type of {padding_style}, only accept 'same' or 'valid'" + ) + + +def conv_padding_for_same(dilation, kernel_size): + total_pad = dilation * (kernel_size - 1) + left_pad = total_pad // 2 + right_pad = total_pad - left_pad + return left_pad, right_pad + + +def conv_backward(func, ctx, grad_output): + def weight_grad_sample(weight): + if batch_size < THRESHOLD and groups == 1: + return conv_group_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, + ) + else: + return conv_unfold_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, + ) + + def expand(param): + if isinstance(param, int): + return conv_picker(func, (param,), (param, param), (param, param, param)) + else: + return param + + def calc_total_padding(func, was_same, padding, dilation, kernel_size): + if was_same: + all_padding = int_padding_for_string_padding( + func, "same", dilation, kernel_size + ) + # F.pad needs the padding in reverse order from what conv expects + total_padding = tuple( + all_padding[i] + all_padding[i - 1] + for i in range(len(all_padding) - 1, -1, -2) + ) + return total_padding + else: + return tuple(2 * pad for pad in padding) + + weight_shape = ctx.weight.shape + stride, padding, dilation, groups = ( + expand(ctx.stride), + expand(ctx.padding), + expand(ctx.dilation), + ctx.groups, + ) + + kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))] + + batch_size = ctx.batch_size + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding + total_padding = calc_total_padding( + func, ctx.was_same_padding, padding, dilation, kernel_size + ) + + if ctx.input_required_grad: + output_padding = [] + input_dims = conv_picker(func, 1, 2, 3) + for i in range(input_dims): + input_dim = ctx.orig_input_shape[2 + i] + output_padding.append( + ( + total_padding[i] + + input_dim + - (kernel_size[i] * dilation[i] - dilation[i] + 1) + ) + % stride[i] + ) + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + transpose_func = conv_picker( + func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d + ) + out = transpose_func( + grad_output, + weight_, + None, + stride, + padding, + tuple(output_padding), + groups, + dilation, + ) + + if ctx.was_same_padding: + for i in range(len(total_padding)): + out = torch.narrow( + out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i] + ) + + results.append(out) + else: + results.append(None) + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 6 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists(ctx.weight, weight_grad_sample) + set_grad_sample_if_exists( + ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2) + ) + return tuple(results) + + +def conv_unfold_weight_grad_sample( + input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, +): + import numpy as np + + n = input.shape[0] + in_channels = input.shape[1] + + unfold_func = conv_picker( + func, + lambda: F.unfold( + input.unsqueeze(-2), + kernel_size=(1, kernel_size[0]), + dilation=(1, dilation[0]), + padding=(0, padding[0]), + stride=(1, stride[0]), + ), + lambda: F.unfold( + input, kernel_size, dilation=dilation, padding=padding, stride=stride + ), + lambda: unfold3d(input, kernel_size, padding, stride, dilation), + ) + + input = unfold_func() + grad_output = grad_output.reshape(n, -1, input.shape[-1]) + + # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz + weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) + # rearrange the above tensor and extract diagonals. + # pyrefly: ignore [no-matching-overload] + weight_grad_sample = weight_grad_sample.view( + n, + groups, + -1, + groups, + int(in_channels / groups), + np.prod(kernel_size), + ) + weight_grad_sample = torch.einsum( + "ngrg...->ngr...", weight_grad_sample + ).contiguous() + shape = [n] + list(weight_shape) + weight_grad_sample = weight_grad_sample.view(shape) + return weight_grad_sample + + +def conv_group_weight_grad_sample( + input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, +): + I = input.shape[1] + O = grad_output.shape[1] + + input_ = input.transpose(0, 1) + grad_output_ = grad_output.view( + grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] + ) + + weight_grad_sample = func( + input_, + grad_output_, + None, + stride=dilation, + padding=padding, + dilation=stride, + groups=batch_size, + ) + input_dims = conv_picker(func, 3, 4, 5) + for i in range(2, input_dims): + weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i]) + weight_grad_sample = weight_grad_sample.view( + I, batch_size, O, *weight_grad_sample.shape[2:] + ) + weight_grad_sample = weight_grad_sample.movedim(0, 2) + return weight_grad_sample + + +def unfold3d( + tensor, + kernel_size, + padding, + stride, + dilation, +): + r""" + Extract sliding local blocks from an batched input tensor. + + :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors). + This method implements the same action for 5D inputs + Args: + tensor: An input tensor of shape ``(B, C, D, H, W)``. + kernel_size: the size of the sliding blocks + padding: implicit zero padding to be added on both sides of input + stride: the stride of the sliding blocks in the input spatial dimensions + dilation: the spacing between the kernel points. + Returns: + A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L - output spatial dimensions. + See :class:`torch.nn.Unfold` for more details + Example: + >>> # xdoctest: +SKIP + >>> B, C, D, H, W = 3, 4, 5, 6, 7 + >>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W) + >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape + torch.Size([3, 32, 120]) + """ + + import numpy as np + + if len(tensor.shape) != 5: + raise ValueError( + f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}" + ) + + if dilation != (1, 1, 1): + raise NotImplementedError(f"dilation={dilation} not supported.") + + batch_size, channels, _, _, _ = tensor.shape + + # Input shape: (B, C, D, H, W) + tensor = F.pad( + tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) + ) + # Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0]) + + tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0]) + tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1]) + tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2]) + # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2]) + # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold` + + tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7) + # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2]) + + tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose( + 1, 2 + ) + # Output shape: (B, D_out * H_out * W_out, C * kernel_size[0] * kernel_size[1] * kernel_size[2] + + return tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..74350b88b5407a01dac92270f8471cc0e37a99c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -0,0 +1,88 @@ +from typing import Any + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, +) + + +@implements_per_sample_grads(F.embedding) +class EmbeddingPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward( + ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any + ) -> torch.Tensor: + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + if len(expanded_args[0].shape) == 1: + raise RuntimeError( + f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}" + ) + output = forward_helper(F.embedding, expanded_args, expanded_kwargs) + ctx.input, ctx.weight = expanded_args + ctx.padding_idx, ctx.scale_grad_by_freq = ( + expanded_kwargs["padding_idx"], + expanded_kwargs["scale_grad_by_freq"], + ) + ctx.sparse = expanded_kwargs["sparse"] + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward( + ctx: Any, grad_output: torch.Tensor + ) -> tuple[torch.Tensor | None, ...]: + input, weight = ctx.input, ctx.weight + padding_idx, scale_grad_by_freq, sparse = ( + ctx.padding_idx, + ctx.scale_grad_by_freq, + ctx.sparse, + ) + + def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor: + batch_size = input.shape[0] + embedding_dim = weight.shape[1] + index = ( + input.unsqueeze(-1) + .expand(*input.shape, embedding_dim) + .reshape(batch_size, -1, embedding_dim) + ) + grad_sample = torch.zeros( # type: ignore[attr-defined] + batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype + ) + return grad_sample.scatter_add_( + 1, index, grad_output.reshape(batch_size, -1, embedding_dim) + ) + + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + bw_fn = torch.ops.aten.embedding_backward + results.append( + bw_fn( + grad_output, + input, + weight.shape[0], + padding_idx, + scale_grad_by_freq, + sparse, + ) + ) + else: + results.append(None) + + # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward) + results = results + [None] * 6 + + # set grad_sample field for weight with per sample gradients + set_grad_sample_if_exists(weight, weight_per_sample_grad) + return tuple(results) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..58ef67e06148a0ee3e2c493d7071ff55f183f06b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -0,0 +1,186 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Callable +from contextlib import contextmanager + +import torch +from torch._decomp import decomposition_table +from torch.utils._pytree import tree_map_only + + +HANDLED_FUNCTIONS: dict[Callable, torch.autograd.Function] = {} + +aten = torch._ops.ops.aten +# __torch_function__ runs before the pydispatcher so we need to manually use the same +# decompositions indexed by their torch equivalent +expanded_weights_rnn_decomps = { + # func: (input_decomp, data_decomp) + torch.rnn_relu: ( + decomposition_table[aten.rnn_relu.input], + decomposition_table[aten.rnn_relu.data], + ), + torch.rnn_tanh: ( + decomposition_table[aten.rnn_tanh.input], + decomposition_table[aten.rnn_tanh.data], + ), + torch.lstm: ( + decomposition_table[aten.lstm.input], + decomposition_table[aten.lstm.data], + ), + torch.gru: ( + decomposition_table[aten.gru.input], + decomposition_table[aten.gru.data], + ), +} + + +# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set +@contextmanager +def batch_second(args, kwargs): + def set_batch_second(ew) -> None: + ew.set_batch_first(False) + + def reset_batch_first(ew) -> None: + ew.set_batch_first(True) + + tree_map_only(ExpandedWeight, set_batch_second, args) + tree_map_only(ExpandedWeight, set_batch_second, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset_batch_first, args) + tree_map_only(ExpandedWeight, reset_batch_first, kwargs) + + +# to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch +@contextmanager +def allow_smaller_batches(args, kwargs): + def allow(ew) -> None: + ew.set_allow_smaller_batches(True) + + def reset(ew) -> None: + ew.set_allow_smaller_batches(False) + + tree_map_only(ExpandedWeight, allow, args) + tree_map_only(ExpandedWeight, allow, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset, args) + tree_map_only(ExpandedWeight, reset, kwargs) + + +@contextmanager +def setup_rnn(use_input_variant, args, kwargs): + with ( + batch_second(args, kwargs) + if use_input_variant + else allow_smaller_batches(args, kwargs) + ): + yield + + +def implements_per_sample_grads(torch_function): + @functools.wraps(torch_function) + def decorator(autograd_func): + HANDLED_FUNCTIONS[torch_function] = autograd_func + return autograd_func + + return decorator + + +# ExpandedWeight represents a weight (parameter) Tensor that has an expanded +# batch dimension. Operations on the ExpandedWeight Tensor act exactly like +# those without an expanded batch dimension but a call to .backward() populates +# the original (unexpanded) tensor with per-sample-gradients for in the grad_sample field +# +# ExpandedWeight has a fallback that always fails since we cannot know what the batch +# dimension of the input tensor is and therefore cannot know if this is a valid call +# +# This is a __torch_function__ object but it could have also been a Tensor Extension +# with a dispatch key. +# +# Needs to be a tensor subclass to allow reparameterization +class ExpandedWeight(torch.Tensor): + def __init__(self, orig_weight, batch_size, loss_reduction) -> None: + self.batch_size = batch_size + self.batch_first = True + self.allow_smaller_batches = False + self.orig_weight = orig_weight + self.loss_reduction = loss_reduction + + handled_functions = HANDLED_FUNCTIONS + + def __new__(cls, orig_weight, batch_size, loss_reduction): + if not isinstance(orig_weight, torch.Tensor): + raise RuntimeError( + f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}" + ) + if not orig_weight.requires_grad: + raise RuntimeError( + "Can only build ExpandedWeights objects of tensors that require_grad" + ) + ret = torch.Tensor._make_subclass(cls, orig_weight, True) + return ret + + @classmethod + def __torch_function__(cls, func, _, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in expanded_weights_rnn_decomps: + # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that + decomp_opts = expanded_weights_rnn_decomps[func] + use_input_variant = isinstance( + # pyrefly: ignore [index-error] + args[2], + list, + ) # data variant uses a list here + decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] + + if decomp is not None: + with setup_rnn(use_input_variant, args, kwargs): + return decomp(*args, **kwargs) + if func is torch._cudnn_rnn_flatten_weight: + # since we aren't using the fused cuda kernels for RNNs, don't do this + return + if func in cls.handled_functions: + return cls.handled_functions[func].apply( + tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())) + ) + # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs, + # i.e. torch.add(torch.Tensor, ExpandedWeight) + raise RuntimeError( + f"Expanded Weights encountered but cannot handle function {func.__name__}" + ) + + @property + def dtype(self): # type: ignore[override] + return self.orig_weight.dtype + + @property + def data(self): # type: ignore[override] + return self.orig_weight.data + + @property + def shape(self): # type: ignore[override] + return self.orig_weight.shape + + @property + def device(self): # type: ignore[override] + return self.orig_weight.device + + @property + def is_cuda(self): # type: ignore[override] + return self.orig_weight.is_cuda + + def data_ptr(self): + return self.orig_weight.data_ptr() + + def get_device(self): + return self.orig_weight.get_device() + + def set_allow_smaller_batches(self, is_allow_smaller_batches) -> None: + self.allow_smaller_batches = is_allow_smaller_batches + + def set_batch_first(self, is_batch_first=True) -> None: + self.batch_first = is_batch_first diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32f24cb4f5d04b288ab93827cc7934ace907eb0e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs + +import torch + +from .expanded_weights_impl import ExpandedWeight + + +def is_batch_first(expanded_args_and_kwargs): + batch_first = None + # pyrefly: ignore [bad-assignment] + for arg in expanded_args_and_kwargs: + if not isinstance(arg, ExpandedWeight): + continue + + if not batch_first: + batch_first = arg.batch_first + elif arg.batch_first != batch_first: + raise RuntimeError( + "Got conflicting batch_first arguments in the same layer" + ) + return batch_first + + +def standard_kwargs(kwarg_names, expanded_args): + r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs. + + Most `__torch_function__`s standardize the kwargs that they give, so this will separate + the args and kwargs they pass. Functions that don't are linear and convND. + """ + kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names) :] + expanded_args_without_kwargs = expanded_args[ + : len(expanded_args) - len(kwarg_names) + ] + expanded_kwargs = dict(zip(kwarg_names, kwarg_values, strict=True)) + return expanded_args_without_kwargs, expanded_kwargs + + +def forward_helper(func, expanded_args, expanded_kwargs): + r"""Compute the forward pass for a function that has expanded weight(s) passed to it. + + It will run the forward pass where all ExpandedWeights are their original + weight. It runs checks on the given arguments and detaches the outputs. + + .. note:: First argument in :attr:`expanded_args` must be the input with the batch + dimension as the first element of the shape + + .. note:: :attr:`func` must return a Tensor or tuple of Tensors + + Args: + func: The function to be called + expanded_args: Arguments to be passed to :attr:`func`. Will include arguments + that need to be unpacked because they are ExpandedWeights + expanded_kwargs: Keyword arguments to be passed to :attr:`func`. + Similar to :attr:`expanded_args`. + """ + unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args( + func, expanded_args, expanded_kwargs + ) + return func(*unexpanded_args, **unexpanded_kwargs) + + +def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): + # input must be the first argument passed + input = expanded_args[0] + if isinstance(input, ExpandedWeight): + raise RuntimeError( + "Expanded Weights do not support inputs that are also ExpandedWeights. " + f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}" + ) + if not isinstance(input, torch.Tensor): + raise RuntimeError( + "Expanded Weights requires a Tensor as the first input to get the batch dimension, " + f"got {type(input).__name__} in function {func.__name__}" + ) + if len(input.shape) == 0: + raise RuntimeError( + f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}" + ) + if input.shape[0] == 0: + raise RuntimeError( + "0 is not a valid batch size for Expanded Weights but got input tensor of " + f"{input} in function {func.__name__}" + ) + for arg in expanded_args + tuple(expanded_kwargs.values()): + if not isinstance(arg, ExpandedWeight): + continue + batch_size = input.shape[0] if arg.batch_first else input.shape[1] + if (arg.allow_smaller_batches and batch_size > arg.batch_size) or ( + not arg.allow_smaller_batches and arg.batch_size != batch_size + ): + raise RuntimeError( + "Expected ExpandedWeights to have batch size matching input but got " + f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}" + ) + + loss_reduction: str | None = None + for arg in expanded_args + tuple(expanded_kwargs.values()): + if isinstance(arg, ExpandedWeight): + if loss_reduction is None: + loss_reduction = arg.loss_reduction + elif loss_reduction != arg.loss_reduction: + raise RuntimeError( + "Expected ExpandedWeights to all have the same loss_reduction argument but got one" + f"with {loss_reduction} and one with {arg.loss_reduction}" + ) + + unexpanded_args = tuple( + arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for arg in expanded_args + ) + unexpanded_kwargs = { + name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for (name, arg) in expanded_kwargs.items() + } + return unexpanded_args, unexpanded_kwargs + + +def maybe_scale_by_batch_size(grad_sample, expanded_weight): + if expanded_weight.loss_reduction == "mean": + return grad_sample * expanded_weight.batch_size + else: + return grad_sample + + +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn) -> None: + unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) + if isinstance(maybe_expanded_weight, ExpandedWeight): + grad_sample_contribution = maybe_scale_by_batch_size( + per_sample_grad_fn(unpacked), maybe_expanded_weight + ) + + if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: + # this only passes the other checks if the arg allows smaller batch sizes + intermediate = torch.zeros( + maybe_expanded_weight.batch_size, + *grad_sample_contribution.shape[1:], + dtype=grad_sample_contribution.dtype, + device=grad_sample_contribution.device, + ) + intermediate[: grad_sample_contribution.shape[0]] = grad_sample_contribution + grad_sample_contribution = intermediate + + if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: + unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution + else: + unpacked.grad_sample = grad_sample_contribution + + +def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): + if isinstance(maybe_expanded_weight, ExpandedWeight): + orig_weight = maybe_expanded_weight.orig_weight + return func(orig_weight) + elif ( + isinstance(maybe_expanded_weight, torch.Tensor) + and not maybe_expanded_weight.requires_grad + ): + return func(maybe_expanded_weight) + elif isinstance(maybe_expanded_weight, torch.Tensor): + raise RuntimeError( + "ExpandedWeights currently does not support a mixture of ExpandedWeight parameters " + "and normal Parameters. Please file and issue with pytorch/pytorch" + ) + + +def sum_over_all_but_batch_and_last_n( + tensor: torch.Tensor, + n_dims: int, +) -> torch.Tensor: + r""" + Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims. + + This function will ignore the first dimension and it will + not aggregate over the last n_dims dimensions. + Args: + tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. + n_dims: Number of dimensions to keep. + Example: + >>> tensor = torch.ones(1, 2, 3, 4, 5) + >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape + torch.Size([1, 4, 5]) + Returns: + A tensor of shape ``(B, ..., X[n_dims-1])`` + """ + if tensor.dim() == n_dims + 1: + return tensor + else: + dims = list(range(1, tensor.dim() - n_dims)) + return tensor.sum(dim=dims) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..373222c2f049af31131594d0322c8ea835d9f2d3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.group_norm) +class GroupNormPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input, num_groups = expanded_args + N = input.shape[0] + C = input.shape[1] + HxW = reduce(operator.mul, input.shape[2:], 1) + weight, bias, eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + output, mean, rstd = forward_helper( + torch.native_group_norm, + (input, weight, bias, N, C, HxW, num_groups, eps), + {}, + ) + ctx.input, ctx.num_groups = input, num_groups + ctx.weight, ctx.eps = weight, eps + ctx.mean, ctx.rstd = mean, rstd + if isinstance(bias, ExpandedWeight): + ctx.bias = bias + if input.requires_grad and isinstance(weight, ExpandedWeight): + ctx.weight = weight + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + input, num_groups = ctx.input, ctx.num_groups + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + mean, rstd = ctx.mean, ctx.rstd + + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + weight_c = unpack_expanded_weight_or_tensor( + weight, lambda t: t.contiguous() + ) + input_c = input.contiguous() + grad_output_c = ( + grad_output.contiguous() if grad_output is not None else None + ) + N = input.shape[0] + C = input.shape[1] + HxW = 1 + for s in input.shape[2:]: + HxW *= s + bw_fn = torch.ops.aten.native_group_norm_backward + results.append( + bw_fn( + grad_output_c, + input_c, + mean, + rstd, + weight_c, + N, + C, + HxW, + num_groups, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", + # pyrefly: ignore [unsupported-operation] + F.group_norm(input, num_groups, eps=eps) * grad_output, + ), + ) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5237cb4e32472204679618b9800b4b85062cc1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +from functools import partial + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.instance_norm) +class InstanceNormPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + instance_norm = partial(torch.instance_norm, cudnn_enabled=True) + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + output = forward_helper(instance_norm, expanded_args, expanded_kwargs) + ctx.input = expanded_args[0] + ctx.running_mean, ctx.running_var = ( + expanded_kwargs["running_mean"], + expanded_kwargs["running_var"], + ) + ctx.weight, ctx.bias, ctx.eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + b = input.shape[0] + c = input.shape[1] + new_shape = (1, b * c, *input.shape[2:]) + + weight_ = unpack_expanded_weight_or_tensor( + weight, lambda orig_weight: orig_weight.repeat(b) + ) + running_mean_ = running_mean.repeat(b) if running_mean is not None else None + running_var_ = running_var.repeat(b) if running_var is not None else None + input_reshaped = input.contiguous().view(new_shape) + grad_output_reshaped = grad_output.contiguous().view(new_shape) + mean = torch.mean( + input_reshaped, (0,) + tuple(range(2, input.dim())), False + ) + var = torch.var( + input_reshaped, + (0,) + tuple(range(2, input.dim())), + keepdim=False, + unbiased=False, + ) + rstd = 1 / torch.sqrt(var + eps) + + # must use native batch norm since it supports all inputs. This may have used cuda or openmi during the forward but + # it didn't save the metadata, so we don't know during the backward + res = torch.ops.aten.native_batch_norm_backward( + grad_output_reshaped, + input_reshaped, + weight_, + running_mean_, + running_var_, + mean, + rstd, + True, + eps, + (True, False, False), + ) + results.append(res[0].reshape(input.shape)) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable (2 are not saved from the forward) + results = results + [None] * 7 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", F.instance_norm(input, eps=eps) * grad_output + ), + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..705253861dbd0c0e16300ecc96561554c8ac6d60 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -0,0 +1,88 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + sum_over_all_but_batch_and_last_n, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.layer_norm) +class LayerNormPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input = expanded_args[0] + normalized_shape = expanded_args[1] + if len(input.shape) <= len(normalized_shape): + raise RuntimeError( + "Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient" + f"computations but got that normalized shape, {normalized_shape}, matched input shape." + ) + output, mean, rstd = forward_helper( + torch.native_layer_norm, expanded_args, expanded_kwargs + ) + ctx.args = expanded_args + + if input.requires_grad or isinstance(expanded_kwargs["weight"], ExpandedWeight): + ctx.weight = expanded_kwargs["weight"] + if input.requires_grad or isinstance(expanded_kwargs["bias"], ExpandedWeight): + ctx.bias = expanded_kwargs["bias"] + ctx.eps = expanded_kwargs["eps"] + ctx.mean, ctx.rstd = mean, rstd + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + def weight_per_sample_grad(weight): + return sum_over_all_but_batch_and_last_n( + F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output, + weight.dim(), + ) + + input, normalized_shape = ctx.args + mean, rstd = ctx.mean, ctx.rstd + + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + bias_ = unpack_expanded_weight_or_tensor(ctx.bias) + results.append( + torch.ops.aten.native_layer_norm_backward( + grad_output, + input, + normalized_shape, + mean, + rstd, + weight_, + bias_, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + ctx.bias, + lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim()), + ) + return tuple(results) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd6b96f58bd614e1004de0ce939cdb90a85dc67 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + is_batch_first, + set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.linear) +class LinearPerSampleGrad(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, _, __, *expanded_args_and_kwargs): + if len(expanded_args_and_kwargs[0].shape) <= 1: + raise RuntimeError( + "Input does not have a batch dimension. Expanded Weights expected input " + f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}" + ) + expanded_kwargs = { + "bias": expanded_args_and_kwargs[2] + if len(expanded_args_and_kwargs) == 3 + else None + } + expanded_args = expanded_args_and_kwargs[:2] + ctx.batch_first = is_batch_first(expanded_args_and_kwargs) + output = forward_helper(F.linear, expanded_args, expanded_kwargs) + ctx.args = expanded_args + ctx.kwargs = expanded_kwargs + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + input, weight = ctx.args + bias = ctx.kwargs["bias"] + results: list[torch.Tensor | None] = [] + results.append(None) # for kwarg_names + results.append(None) # for op reference + + if input.requires_grad: + results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight))) + else: + results.append(None) + results.extend([None] * 2) # weight and bias don't compute batched gradients + + if not ctx.batch_first: + grad_output = grad_output.transpose(0, 1) + input = input.transpose(0, 1) + + # weight and bias get their grad_sample fields set directly if they exist + set_grad_sample_if_exists( + weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input) + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("n...k->nk", grad_output) + ) + return tuple(results) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_named_member_accessor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_named_member_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..0935490856aebf3503aa126e51d342c3bac0b529 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_named_member_accessor.py @@ -0,0 +1,373 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Iterable + +import torch + + +_MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + if name in module._parameters: + module._parameters[name] = tensor # type: ignore[assignment] + elif name in module._buffers: + module._buffers[name] = tensor + else: + setattr(module, name, tensor) + + +def swap_tensor( + module: "torch.nn.Module", + name: str, + tensor: torch.Tensor, + allow_missing: bool = False, +) -> torch.Tensor: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if ( + tensor is not _MISSING + and not isinstance(tensor, torch.Tensor) + and tensor is not None + ): + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + + orig_tensor: torch.Tensor + if name in module._parameters: + orig_tensor = module._parameters[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._parameters[name] = tensor # type: ignore[assignment] + else: + del module._parameters[name] + elif name in module._buffers: + orig_tensor = module._buffers[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._buffers[name] = tensor + else: + del module._buffers[name] + else: + if hasattr(module, name): + orig_tensor = getattr(module, name) + else: + if not allow_missing: + raise AttributeError(f"{module._get_name()} has no attribute `{name}`") + orig_tensor = _MISSING + if ( + orig_tensor is not _MISSING + and not isinstance(orig_tensor, torch.Tensor) + and orig_tensor is not None + ): + raise TypeError( + f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor" + ) + if tensor is not _MISSING: + setattr(module, name, tensor) + elif hasattr(module, name): + delattr(module, name) + # pyrefly: ignore [bad-return] + return orig_tensor + + +def swap_submodule( + module: "torch.nn.Module", + name: str, + submodule: "torch.nn.Module", +) -> "torch.nn.Module": + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(submodule, torch.nn.Module): + raise TypeError(f"{submodule} is not an instance of torch.nn.Module") + if "." in name: + raise KeyError('submodule name can\'t contain "."') + if name == "": + raise KeyError('submodule name can\'t be empty string ""') + if name not in module._modules: + raise KeyError(f"submodule {name} does not exist") + + orig_submodule = module._modules[name] + if not isinstance(orig_submodule, torch.nn.Module): + raise TypeError(f"{name} attribute is not an instance of torch.nn.Module") + module._modules[name] = submodule + return orig_submodule + + +class NamedMemberAccessor: + """ + A class that provides a way to access the submodules and parameters/buffers of a module. + + It provides caching mechanism to speed up submodule lookups. + This is useful for functional programming to manipulate the module state. + """ + + def __init__(self, module: "torch.nn.Module") -> None: + self.module = module + self.memo: dict[str, torch.nn.Module] = {} + + # Nested attribute access + + def get_submodule(self, name: str) -> "torch.nn.Module": + """ + Return the submodule specified by the given path. + + For example, to get the submodule mod.layer1.conv1, + use accessor.get_submodule("layer1.conv1") + + Compare to mod.get_submodule("layer1.conv1"), this method will cache the + intermediate submodule access to speed up future lookups. + """ + if not name: + return self.module + + if name in self.memo: + return self.memo[name] + else: + prefix, dot, attr = name.rpartition(".") + if dot: + module = self.get_submodule(prefix) + else: + module = self.module + try: + submodule = getattr(module, attr) + except AttributeError as ex: + raise AttributeError( + f"{module._get_name()} has no attribute `{attr}`" + ) from ex + if not isinstance(submodule, torch.nn.Module): + raise TypeError( + f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" + ) + self.memo[name] = submodule + return submodule + + def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module": + """ + Swap the submodule specified by the given ``path`` to ``value``. + + For example, to swap the attribute mod.layer1.conv1 use + ``accessor.swap_submodule("layer1.conv1", conv2)``. + """ + prefix, _, attr = path.rpartition(".") + return swap_submodule(self.get_submodule(prefix), attr, value) + + def get_tensor(self, name: str) -> torch.Tensor: + """ + Get the tensor specified by the given path to value. + + For example, to get the attribute mod.layer1.conv1.weight, + use accessor.get_tensor('layer1.conv1.weight') + + Compare to mod.get_parameter("layer1.conv1.weight"), this method will + cache the intermediate submodule access to speed up future lookups. + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + tensor = getattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + return tensor # type: ignore[return-value] + + def set_tensor(self, name: str, value: torch.Tensor) -> None: + """ + Set the attribute specified by the given path to value. + + For example, to set the attribute mod.layer1.conv1.weight, + use accessor.set_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + set_tensor(self.get_submodule(prefix), attr, value) + + def del_tensor(self, name: str) -> None: + """ + Delete the attribute specified by the given path. + + For example, to delete the attribute mod.layer1.conv1.weight, + use accessor.del_tensor("layer1.conv1.weight") + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + delattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + + def swap_tensor( + self, name: str, value: torch.Tensor, allow_missing: bool = False + ) -> torch.Tensor: + """ + Swap the attribute specified by the given path to value. + + For example, to swap the attribute mod.layer1.conv1.weight, + use accessor.swap_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + return swap_tensor( + self.get_submodule(prefix), attr, value, allow_missing=allow_missing + ) + + # Batched operations + + def get_tensors(self, names: Iterable[str]) -> list[torch.Tensor]: + """ + Get the tensors specified by the given paths. + + For example, to get the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + return [self.get_tensor(name) for name in names] + + def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. + + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"], [weight, bias]) + """ + if not isinstance(names, (list, tuple)): + names = list(names) + if not isinstance(values, (list, tuple)): + values = list(values) + assert len(names) == len(values), "names and values must have the same length" + + for name, value in zip(names, values, strict=True): + self.set_tensor(name, value) + + def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. + + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors_dict({ + "layer1.conv1.weight": weight, + "layer1.conv1.bias": bias, + }) + """ + for name, value in named_tensors.items(): + self.set_tensor(name, value) + + def del_tensors(self, names: Iterable[str]) -> None: + """ + Delete the attributes specified by the given paths. + + For example, to delete the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + for name in names: + self.del_tensor(name) + + def swap_tensors( + self, + names: Iterable[str], + values: Iterable[torch.Tensor], + allow_missing: bool = False, + ) -> list[torch.Tensor]: + """ + Swap the attributes specified by the given paths to values. + + For example, to swap the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"], [weight, bias]) + """ + if not isinstance(names, (list, tuple)): + names = list(names) + if not isinstance(values, (list, tuple)): + values = list(values) + assert len(names) == len(values), "names and values must have the same length" + + return [ + self.swap_tensor(name, value, allow_missing=allow_missing) + for name, value in zip(names, values, strict=True) + ] + + def swap_tensors_dict( + self, named_tensors: dict[str, torch.Tensor], allow_missing: bool = False + ) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + Swap the attributes specified by the given paths to values. + + For example, to swap the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.swap_tensors_dict({ + "layer1.conv1.weight": weight, + "layer1.conv1.bias": bias, + }) + """ + orig_named_tensors = {} + missing_keys = [] + try: + for name, tensor in named_tensors.items(): + orig_tensor = self.swap_tensor(name, tensor, allow_missing=True) + if orig_tensor is _MISSING: + missing_keys.append(name) + orig_named_tensors[name] = orig_tensor + except Exception: + # Swap back if any exception occurs + for name, orig_tensor in orig_named_tensors.items(): + self.swap_tensor(name, orig_tensor, allow_missing=True) + raise + if missing_keys and not allow_missing: + # Swap back if any key is missing when allow_missing is False + for name, orig_tensor in orig_named_tensors.items(): + self.swap_tensor(name, orig_tensor, allow_missing=True) + raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") + return orig_named_tensors, missing_keys + + def check_keys(self, keys: Iterable[str]) -> tuple[list[str], list[str]]: + """Check that the given keys are valid.""" + keys = set(keys) + valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)} + missing_keys = valid_keys - keys + unexpected_keys = keys - valid_keys + return sorted(missing_keys), sorted(unexpected_keys) + + # Shortcut methods + + def named_parameters( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the parameters in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + + def named_buffers( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the buffers in the module.""" + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_tensors( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, torch.Tensor]]: + """Iterate over all the tensors in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_modules( + self, + remove_duplicate: bool = True, + ) -> Iterable[tuple[str, "torch.nn.Module"]]: + """Iterate over all the modules in the module.""" + yield from self.module.named_modules(remove_duplicate=remove_duplicate) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_per_sample_grad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_per_sample_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2eae0865845eec9c426c5cc3b7bff1b11b5b1230 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/_per_sample_grad.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight +from torch.utils import _pytree as pytree + + +# dependency on `functional_call` means that this can't be exposed in utils +# without creating circular dependency +def call_for_per_sample_grads( + module, + *, + batch_size=None, + loss_reduction="sum", + batch_first=True, +): + r""" + Return a forward function for a module, populating grad_sample with per sample gradients on backward invocation. + + Args: + module: The ``nn.Module`` to get per sample gradients with respect to. All trainable + parameters will compute per sample gradients, located in a ``grad_sample`` + field when ``backward`` is invoked + batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have + the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually. + Default: None + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If + "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from + running mean across a batch. Must be "mean" or "sum". Default: "sum" + batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first + dimension. If False, it's the second dimension. Default: True. + + Examples:: + >>> # xdoctest: +SKIP + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model)(batched_input).sum() + >>> res.backward() + >>> assert model.weight.shape == (3, 4) + >>> assert model.weight.grad_sample.shape == (5, 3, 4) + >>> assert model.weight.grad is None + >>> assert model.bias.shape == (3,) + >>> assert model.bias.grad_sample.shape == (5, 3) + >>> assert model.bias.grad is None + + An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be + if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all + grad_outputs by 1 / batch_size from cross batch interaction. + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")( + ... batched_input + ... ).mean() + >>> res.backward() + + Note:: + Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom + rewrites that wrap an `nn.Linear` module. See Opacus for an example + """ + + def maybe_build_expanded_weight(og_tensor, batch_size): + if og_tensor.requires_grad: + return ExpandedWeight(og_tensor, batch_size, loss_reduction) + else: + return og_tensor + + def compute_batch_size(*args, **kwargs): + args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs) + batch_size = None + for arg in args_and_kwargs: + if not isinstance(arg, torch.Tensor): + continue + + arg_batch_size = arg.shape[0] if batch_first else arg.shape[1] + if batch_size is not None and batch_size != arg_batch_size: + raise RuntimeError( + "When computing batch size, found at least one input with batch size " + f"{batch_size} and one with batch size {arg_batch_size}. Please specify it " + "explicitly using the batch size kwarg in call_for_per_sample_grads" + ) + batch_size = arg_batch_size + if batch_size is None: + raise RuntimeError( + "Unable to find a tensor in the passed args and kwargs. They may not be pytree-able " + "and so ExpandedWeights cannot compute the batch size from the inputs. Please specify " + "it explicitly" + ) + return batch_size + + if loss_reduction not in ["sum", "mean"]: + raise RuntimeError( + f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}" + ) + + if not isinstance(module, torch.nn.Module): + raise RuntimeError( + f"Module passed must be nn.Module, got {type(module).__name__}" + ) + if not (batch_size is None or isinstance(batch_size, int)): + raise RuntimeError( + f"Batch size passed must be None or an integer, got {type(batch_size).__name__}" + ) + if batch_size is not None and batch_size < 1: + raise RuntimeError(f"Batch size must be positive, got {batch_size}") + for weight in module.parameters(): + if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined] + raise RuntimeError( + "Current Expanded Weights accumulates the gradients, which will be incorrect for multiple " + f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or " + "post an issue to pytorch/pytorch to prioritize correct behavior" + ) + + @functools.wraps(module.forward) + def wrapper(*args, **kwargs): + wrapper_batch_size = batch_size + if wrapper_batch_size is None: + wrapper_batch_size = compute_batch_size(*args, **kwargs) + + params = { + name: maybe_build_expanded_weight(value, wrapper_batch_size) + for (name, value) in module.named_parameters() + } + return torch.func.functional_call(module, params, args, kwargs) + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/clip_grad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..30202708bfa38bb8437627152fb76061955e31f9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/clip_grad.py @@ -0,0 +1,299 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import types +import typing +import warnings +from collections.abc import Callable +from typing import cast, TypeAlias, TypeVar +from typing_extensions import deprecated, ParamSpec + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +__all__: list[str] = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grad_value_", +] + + +_tensor_or_tensors: TypeAlias = torch.Tensor | typing.Iterable[torch.Tensor] # noqa: PYI042 + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]: + """ + This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions + clip_grad_norm_ and clip_grad_value_ themselves. + """ + + def _no_grad_wrapper(*args, **kwargs): + with torch.no_grad(): + # pyrefly: ignore [invalid-param-spec] + return func(*args, **kwargs) + + functools.update_wrapper(_no_grad_wrapper, func) + # pyrefly: ignore [bad-return] + return _no_grad_wrapper + + +@_no_grad +def _get_total_norm( + tensors: _tensor_or_tensors, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: bool | None = None, +) -> torch.Tensor: + r"""Compute the norm of an iterable of tensors. + + The norm is computed over the norms of the individual tensors, as if the norms of + the individual tensors were concatenated into a single vector. + + Args: + tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will be normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``. + Default: ``False`` + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the tensors (viewed as a single vector). + """ + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + tensors = list(tensors) + norm_type = float(norm_type) + if len(tensors) == 0: + return torch.tensor(0.0) + first_device = tensors[0].device + grouped_tensors: dict[ + tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]] + ] = _group_tensors_by_device_and_dtype( + [tensors] # type: ignore[list-item] + ) # type: ignore[assignment] + + norms: list[Tensor] = [] + for (device, _), ([device_tensors], _) in grouped_tensors.items(): + if (foreach is None and _has_foreach_support(device_tensors, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_tensors, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend( + [torch.linalg.vector_norm(g, norm_type) for g in device_tensors] + ) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + return total_norm + + +@_no_grad +def _clip_grads_with_norm_( + parameters: _tensor_or_tensors, + max_norm: float, + total_norm: torch.Tensor, + foreach: bool | None = None, +) -> None: + r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm. + + The gradients will be scaled by the following calculation + + .. math:: + grad = grad * \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) + + Gradients are modified in-place. + + Note: The scale coefficient is clamped to a maximum of 1.0 to prevent gradient amplification. + This ensures that gradients are only scaled down when the total norm exceeds max_norm. + + This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated + total norm. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + total_norm (Tensor): total norm of the gradients to use for clipping + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + None + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + if len(grads) == 0: + return + grouped_grads: dict[ + tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]] + ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] + + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + +@_no_grad +def clip_grad_norm_( + parameters: _tensor_or_tensors, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: bool | None = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by + :func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float, optional): type of the used p-norm. Can be ``'inf'`` for + infinity norm. Default: 2.0 + error_if_nonfinite (bool, optional): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False + foreach (bool, optional): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + else: + is_generator = isinstance(parameters, types.GeneratorType) + # prevent generators from being exhausted + parameters = list(parameters) + if is_generator and len(parameters) == 0: + warnings.warn( + "`parameters` is an empty generator, no gradient clipping will occur.", + stacklevel=3, + ) + grads = [p.grad for p in parameters if p.grad is not None] + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +@deprecated( + "`torch.nn.utils.clip_grad_norm` is now deprecated " + "in favor of `torch.nn.utils.clip_grad_norm_`.", + category=FutureWarning, +) +def clip_grad_norm( + parameters: _tensor_or_tensors, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: bool | None = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + .. warning:: + This method is now deprecated in favor of + :func:`torch.nn.utils.clip_grad_norm_`. + """ + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) + + +@_no_grad +def clip_grad_value_( + parameters: _tensor_or_tensors, + clip_value: float, + foreach: bool | None = None, +) -> None: + r"""Clip the gradients of an iterable of parameters at specified value. + + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + clip_value (float): maximum allowed value of the gradients. + The gradients are clipped in the range + :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` + foreach (bool, optional): use the faster foreach-based implementation + If ``None``, use the foreach implementation for CUDA and CPU native tensors and + silently fall back to the slow implementation for other device types. + Default: ``None`` + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + clip_value = float(clip_value) + + grads = [p.grad for p in parameters if p.grad is not None] + # pyrefly: ignore [bad-argument-type] + grouped_grads = _group_tensors_by_device_and_dtype([grads]) + + for (device, _), ([grads], _) in grouped_grads.items(): + if ( + foreach is None + and _has_foreach_support(cast(list[Tensor], grads), device=device) + ) or (foreach and _device_has_foreach_support(device)): + torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value) + torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + for grad in grads: + cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/convert_parameters.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/convert_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..6a56da711ecda3c6e3d5770783f100a8890bbf55 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/convert_parameters.py @@ -0,0 +1,90 @@ +from collections.abc import Iterable + +import torch + + +def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor: + r"""Flatten an iterable of parameters into a single vector. + + Args: + parameters (Iterable[Tensor]): an iterable of Tensors that are the + parameters of a model. + + Returns: + The parameters represented by a single vector + """ + # Flag for the device where the parameter is located + param_device = None + + vec = [] + for param in parameters: + # Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device) + + vec.append(param.view(-1)) + return torch.cat(vec) + + +def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None: + r"""Copy slices of a vector into an iterable of parameters. + + Args: + vec (Tensor): a single vector representing the parameters of a model. + parameters (Iterable[Tensor]): an iterable of Tensors that are the + parameters of a model. + """ + # Ensure vec of type Tensor + if not isinstance(vec, torch.Tensor): + raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}") + # Flag for the device where the parameter is located + param_device = None + + # Pointer for slicing the vector for each parameter + pointer = 0 + for param in parameters: + # Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device) + + # The length of the parameter + num_param = param.numel() + # Slice the vector, reshape it, and replace the old data of the parameter + param.data = vec[pointer : pointer + num_param].view_as(param).data + + # Increment the pointer + pointer += num_param + + +def _check_param_device(param: torch.Tensor, old_param_device: int | None) -> int: + r"""Check if the parameters are located on the same device. + + Currently, the conversion between model parameters and single vector form is not supported + for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1. + + Args: + param ([Tensor]): a Tensor of a parameter of a model + old_param_device (int): the device where the first parameter of a + model is allocated. + + Returns: + old_param_device (int): report device for the first time + """ + # Meet the first parameter + support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()] + if old_param_device is None: + old_param_device = ( + param.get_device() if param.device.type in support_device_types else -1 + ) + else: + warn = False + if ( + param.device.type in support_device_types + ): # Check if in same GPU/PrivateUse1 + warn = param.get_device() != old_param_device + else: # Check if in CPU + warn = old_param_device != -1 + if warn: + raise TypeError( + "Found two parameters on different devices, " + "this is currently not supported." + ) + return old_param_device diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/fusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..35406785305117f979479bc2baec0f65d6fdb7af --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/fusion.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import copy +from typing import TypeVar + +import torch + + +__all__ = [ + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", +] + +ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd") +LinearT = TypeVar("LinearT", bound="torch.nn.Linear") + + +def fuse_conv_bn_eval( + conv: ConvT, + bn: torch.nn.modules.batchnorm._BatchNorm, + transpose: bool = False, +) -> ConvT: + r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module. + + Args: + conv (torch.nn.modules.conv._ConvNd): A convolutional module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False. + + Returns: + torch.nn.modules.conv._ConvNd: The fused convolutional module. + + .. note:: + Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (conv.training or bn.training), "Fusion only for eval!" + fused_conv = copy.deepcopy(conv) + + assert bn.running_mean is not None and bn.running_var is not None + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn.running_mean, + bn.running_var, + bn.eps, + bn.weight, + bn.bias, + transpose, + ) + + return fused_conv + + +def fuse_conv_bn_weights( + conv_w: torch.Tensor, + conv_b: torch.Tensor | None, + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: torch.Tensor | None, + bn_b: torch.Tensor | None, + transpose: bool = False, +) -> tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters. + + Args: + conv_w (torch.Tensor): Convolutional weight. + conv_b (Optional[torch.Tensor]): Convolutional bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (Optional[torch.Tensor]): BatchNorm weight. + bn_b (Optional[torch.Tensor]): BatchNorm bias. + transpose (bool, optional): If True, transpose the conv weight. Defaults to False. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias. + """ + conv_weight_dtype = conv_w.dtype + conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to( + dtype=conv_weight_dtype + ) + fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to( + dtype=conv_bias_dtype + ) + + return ( + torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), + torch.nn.Parameter(fused_conv_b, conv_b.requires_grad), + ) + + +def fuse_linear_bn_eval( + linear: LinearT, + bn: torch.nn.modules.batchnorm._BatchNorm, +) -> LinearT: + r"""Fuse a linear module and a BatchNorm module into a single, new linear module. + + Args: + linear (torch.nn.Linear): A Linear module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + + Returns: + torch.nn.Linear: The fused linear module. + + .. note:: + Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (linear.training or bn.training), "Fusion only for eval!" + fused_linear = copy.deepcopy(linear) + + """ + Linear-BN needs to be fused while preserving the shapes of linear weight/bias. + To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear, + because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in). + To be broadcastable, the number of features in bn and + the number of output features from linear must satisfy the following condition: + 1. they are equal, or + 2. the number of features in bn is 1 + Otherwise, skip the folding path + """ + assert linear.out_features == bn.num_features or bn.num_features == 1, ( + "To fuse, linear.out_features == bn.num_features or bn.num_features == 1" + ) + + assert bn.running_mean is not None and bn.running_var is not None + fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights( + fused_linear.weight, + fused_linear.bias, + bn.running_mean, + bn.running_var, + bn.eps, + bn.weight, + bn.bias, + ) + + return fused_linear + + +def fuse_linear_bn_weights( + linear_w: torch.Tensor, + linear_b: torch.Tensor | None, + bn_rm: torch.Tensor, + bn_rv: torch.Tensor, + bn_eps: float, + bn_w: torch.Tensor, + bn_b: torch.Tensor, +) -> tuple[torch.nn.Parameter, torch.nn.Parameter]: + r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters. + + Args: + linear_w (torch.Tensor): Linear weight. + linear_b (Optional[torch.Tensor]): Linear bias. + bn_rm (torch.Tensor): BatchNorm running mean. + bn_rv (torch.Tensor): BatchNorm running variance. + bn_eps (float): BatchNorm epsilon. + bn_w (torch.Tensor): BatchNorm weight. + bn_b (torch.Tensor): BatchNorm bias. + + Returns: + Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias. + """ + linear_weight_dtype = linear_w.dtype + linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype + if linear_b is None: + linear_b = torch.zeros_like(bn_rm) + bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps) + + fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype) + fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype) + + return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter( + fused_b, linear_b.requires_grad + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/init.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..10fa03b7c01c2eac7e474ef55f433e4704e6c778 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/init.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +import inspect + +import torch + + +def skip_init(module_cls, *args, **kwargs): + r""" + Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers. + + This can be useful if initialization is slow or if custom initialization will + be performed, making the default initialization unnecessary. There are some caveats to this, due to + the way this function is implemented: + + 1. The module must accept a `device` arg in its constructor that is passed to any parameters + or buffers created during construction. + + 2. The module must not perform any computation on parameters in its constructor except + initialization (i.e. functions from :mod:`torch.nn.init`). + + If these conditions are satisfied, the module can be instantiated with parameter / buffer values + uninitialized, as if having been created using :func:`torch.empty`. + + Args: + module_cls: Class object; should be a subclass of :class:`torch.nn.Module` + args: args to pass to the module's constructor + kwargs: kwargs to pass to the module's constructor + + Returns: + Instantiated module with uninitialized parameters / buffers + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> import torch + >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) + >>> m.weight + Parameter containing: + tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], + requires_grad=True) + >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) + >>> m2.weight + Parameter containing: + tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, + 4.5915e-41]], requires_grad=True) + + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + if "device" not in inspect.signature(module_cls).parameters: + raise RuntimeError("Module must support a 'device' arg to skip initialization") + + final_device = kwargs.pop("device", "cpu") + kwargs["device"] = "meta" + return module_cls(*args, **kwargs).to_empty(device=final_device) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/memory_format.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/memory_format.py new file mode 100644 index 0000000000000000000000000000000000000000..06eb55a02572d79b6f254624aaea90d86e5430a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/memory_format.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import TypeVar + +import torch + + +_M = TypeVar("_M", bound="torch.nn.Module") + + +def convert_conv2d_weight_memory_format( + module: _M, memory_format: torch.memory_format +) -> _M: + r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``. + + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. + This function is used to facilitate the computation to adopt NHWC kernels, which + provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0 + + .. note:: + Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive + than the utility function ``convert_conv2d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NHWC(channels_last) conversion for + convolution in cuDNN, as it is beneficial to run convolution in NHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format). + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. ``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv2d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda" + ... ) + >>> model = nn.Sequential( + >>> nn.Conv2d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) + >>> model = nn.utils.convert_conv2d_weight_memory_format( + ... model, torch.channels_last + ... ) + >>> out = model(input) + """ + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + weight_data = module.weight.detach().clone(memory_format=memory_format) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv2d_weight_memory_format(child, memory_format) + # pyrefly: ignore [bad-return] + return module + + +def convert_conv3d_weight_memory_format( + module: _M, memory_format: torch.memory_format +) -> _M: + r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format`` + The conversion recursively applies to nested ``nn.Module``, including ``module``. + Note that it only changes the memory_format, but not the semantics of each dimensions. + This function is used to facilitate the computation to adopt NHWC kernels, which + provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0 + + .. note:: + Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive + than the utility function ``convert_conv3d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NDHWC(channels_last_3d) conversion for + convolution in cuDNN, as it is beneficial to run convolution in NDHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last_3d. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format). + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. ``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv3d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda" + ... ) + >>> model = nn.Sequential( + >>> nn.Conv3d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> model = nn.utils.convert_conv3d_weight_memory_format( + ... model, torch.channels_last_3d + ... ) + >>> out = model(input) + """ + + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = module.weight.detach().clone(memory_format=memory_format) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + # pyrefly: ignore [bad-return] + return module + + +__all__ = [ + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrizations.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrizations.py new file mode 100644 index 0000000000000000000000000000000000000000..3a51bbc15c5969bc742bf954243bd8b1b9333bbe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrizations.py @@ -0,0 +1,630 @@ +# mypy: allow-untyped-defs +from enum import auto, Enum + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import Module +from torch.nn.utils import parametrize + + +__all__ = ["orthogonal", "spectral_norm", "weight_norm"] + + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = torch.eye(k, dtype=Q.dtype, device=Q.device) + # A reasonable eps, but not too large + eps = 10.0 * n * torch.finfo(Q.dtype).eps + return torch.allclose(Q.mH @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """Assume that A is a tall matrix. + + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative. + """ + X, tau = torch.geqrf(A) + Q = torch.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__( + self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True + ) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. Saving tau on its own does not work either, because + # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise + # them as independent tensors we would not maintain the constraint + # An equivalent reasoning holds for rectangular matrices + if weight.is_complex() and orthogonal_map == _OrthMaps.householder: + raise ValueError( + "The householder parametrization does not support complex tensors." + ) + + self.shape = weight.shape + self.orthogonal_map = orthogonal_map + if use_trivialization: + self.register_buffer("base", None) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + n, k = X.size(-2), X.size(-1) + transposed = n < k + if transposed: + X = X.mT + n, k = k, n + # Here n > k and X is a tall matrix + if ( + self.orthogonal_map == _OrthMaps.matrix_exp + or self.orthogonal_map == _OrthMaps.cayley + ): + # We just need n x k - k(k-1)/2 parameters + X = X.tril() + if n != k: + # Embed into a square matrix + X = torch.cat( + [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1 + ) + A = X - X.mH + # A is skew-symmetric (or skew-hermitian) + if self.orthogonal_map == _OrthMaps.matrix_exp: + Q = torch.matrix_exp(A) + elif self.orthogonal_map == _OrthMaps.cayley: + # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} + Id = torch.eye(n, dtype=A.dtype, device=A.device) + Q = torch.linalg.solve( + torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5) + ) + # Q is now orthogonal (or unitary) of size (..., n, n) + if n != k: + # pyrefly: ignore [unbound-name] + Q = Q[..., :k] + # Q is now the size of the X (albeit perhaps transposed) + else: + # X is real here, as we do not support householder with complex numbers + A = X.tril(diagonal=-1) + tau = 2.0 / (1.0 + (A * A).sum(dim=-2)) + Q = torch.linalg.householder_product(A, tau) + # The diagonal of X is 1's and -1's + # We do not want to differentiate through this or update the diagonal of X hence the casting + Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) + + if hasattr(self, "base"): + # pyrefly: ignore [unbound-name] + Q = self.base @ Q + if transposed: + # pyrefly: ignore [unbound-name] + Q = Q.mT + return Q # type: ignore[possibly-undefined] + + @torch.autograd.no_grad() + def right_inverse(self, Q: torch.Tensor) -> torch.Tensor: + if Q.shape != self.shape: + raise ValueError( + f"Expected a matrix or batch of matrices of shape {self.shape}. " + f"Got a tensor of shape {Q.shape}." + ) + + Q_init = Q + n, k = Q.size(-2), Q.size(-1) + transpose = n < k + if transpose: + Q = Q.mT + n, k = k, n + + # We always make sure to always copy Q in every path + if not hasattr(self, "base"): + # Note [right_inverse expm cayley] + # If we do not have use_trivialization=True, we just implement the inverse of the forward + # map for the Householder. To see why, think that for the Cayley map, + # we would need to find the matrix X \in R^{n x k} such that: + # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + # A = Y - Y.mH + # cayley(A)[:, :k] + # gives the original tensor. It is not clear how to do this. + # Perhaps via some algebraic manipulation involving the QR like that of + # Corollary 2.2 in Edelman, Arias and Smith? + if ( + self.orthogonal_map == _OrthMaps.cayley + or self.orthogonal_map == _OrthMaps.matrix_exp + ): + raise NotImplementedError( + "It is not possible to assign to the matrix exponential " + "or the Cayley parametrizations when use_trivialization=False." + ) + + # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. + # Here Q is always real because we do not support householder and complex matrices. + # See note [Householder complex] + A, tau = torch.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1 + return A.mT if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = torch.randn( + *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device + ) + Q = torch.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(torch.zeros(m,n)) == torch.eye(m,n) + # Poor man's version of eye_like + neg_Id = torch.zeros_like(Q_init) + neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0) + return neg_Id + + +def orthogonal( + module: Module, + name: str = "weight", + orthogonal_map: str | None = None, + *, + use_trivialization: bool = True, +) -> Module: + r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~torch.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . + + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. + + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> # xdoctest: +IGNORE_WANT + >>> Q = orth_linear.weight + >>> torch.dist(Q.T @ Q, torch.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError( + "Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions." + ) + + if orthogonal_map is None: + orthogonal_map = ( + "matrix_exp" + if weight.size(-2) == weight.size(-1) or weight.is_complex() + else "householder" + ) + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError( + 'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". ' + f"Got: {orthogonal_map}" + ) + orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + +class _WeightNorm(Module): + def __init__( + self, + dim: int | None = 0, + ) -> None: + super().__init__() + if dim is None: + dim = -1 + self.dim = dim + + def forward(self, weight_g, weight_v): + return torch._weight_norm(weight_v, weight_g, self.dim) + + def right_inverse(self, weight): + weight_g = torch.norm_except_dim(weight, 2, self.dim) + weight_v = weight + + return weight_g, weight_v + + +def weight_norm(module: Module, name: str = "weight", dim: int = 0): + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` with two parameters: one specifying the magnitude + and one specifying the direction. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _WeightNorm() + ) + ) + ) + >>> m.parametrizations.weight.original0.size() + torch.Size([40, 1]) + >>> m.parametrizations.weight.original1.size() + torch.Size([40, 20]) + + """ + _weight_norm = _WeightNorm(dim) + parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) + + def _weight_norm_compat_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + g_key = f"{prefix}{name}_g" + v_key = f"{prefix}{name}_v" + if g_key in state_dict and v_key in state_dict: + original0 = state_dict.pop(g_key) + original1 = state_dict.pop(v_key) + state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 + state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 + + module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) + return module + + +class _SpectralNorm(Module): + def __init__( + self, + weight: torch.Tensor, + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + super().__init__() + ndim = weight.ndim + if dim >= ndim or dim < -ndim: + raise IndexError( + "Dimension out of range (expected to be in range of " + f"[-{ndim}, {ndim - 1}] but got {dim})" + ) + + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.dim = dim if dim >= 0 else dim + ndim + self.eps = eps + if ndim > 1: + # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward) + self.n_power_iterations = n_power_iterations + weight_mat = self._reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + u = weight_mat.new_empty(h).normal_(0, 1) + v = weight_mat.new_empty(w).normal_(0, 1) + self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps)) + self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps)) + + # Start with u, v initialized to some reasonable values by performing a number + # of iterations of the power method + self._power_method(weight_mat, 15) + + def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + # Precondition + assert weight.ndim > 1 + + if self.dim != 0: + # permute dim to front + weight = weight.permute( + self.dim, *(d for d in range(weight.dim()) if d != self.dim) + ) + + return weight.flatten(1) + + @torch.autograd.no_grad() + def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None: + # See original note at torch/nn/utils/spectral_norm.py + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + + # Precondition + assert weight_mat.ndim > 1 + + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + self._u = F.normalize( + torch.mv(weight_mat, self._v), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._u, # type: ignore[has-type] + ) + self._v = F.normalize( + torch.mv(weight_mat.H, self._u), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._v, # type: ignore[has-type] + ) + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if weight.ndim == 1: + # Faster and more exact path, no need to approximate anything + return F.normalize(weight, dim=0, eps=self.eps) + else: + weight_mat = self._reshape_weight_to_matrix(weight) + if self.training: + self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.clone(memory_format=torch.contiguous_format) + v = self._v.clone(memory_format=torch.contiguous_format) + # The proper way of computing this should be through F.bilinear, but + # it seems to have some efficiency issues: + # https://github.com/pytorch/pytorch/issues/58093 + sigma = torch.vdot(u, torch.mv(weight_mat, v)) + return weight / sigma + + def right_inverse(self, value: torch.Tensor) -> torch.Tensor: + # we may want to assert here that the passed value already + # satisfies constraints + return value + + +def spectral_norm( + module: Module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: int | None = None, +) -> Module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + When applied on a vector, it simplifies to + + .. math:: + \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant + of the model. :math:`\sigma` is approximated performing one iteration of the + `power method`_ every time the weight is accessed. If the dimension of the + weight tensor is greater than 2, it is reshaped to 2D in power iteration + method to get spectral norm. + + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a + reimplementation of :func:`torch.nn.utils.spectral_norm`. + + .. note:: + When this constraint is registered, the singular vectors associated to the largest + singular value are estimated rather than sampled at random. These are then updated + performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor + is accessed with the module on `training` mode. + + .. note:: + If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, + is in training mode on removal, it will perform another power iteration. + If you'd like to avoid this iteration, set the module to eval mode + before its removal. + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter. Default: ``"weight"``. + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm. Default: ``1``. + eps (float, optional): epsilon for numerical stability in + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. + Default: ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with a new parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> snm = spectral_norm(nn.Linear(20, 40)) + >>> snm + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _SpectralNorm() + ) + ) + ) + >>> torch.linalg.matrix_norm(snm.weight, 2) + tensor(1.0081, grad_fn=) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + parametrize.register_parametrization( + module, name, _SpectralNorm(weight, n_power_iterations, dim, eps) + ) + return module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrize.py new file mode 100644 index 0000000000000000000000000000000000000000..28599db7bdf116f7e3af1bcd7d8576fc2fe51f9b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/parametrize.py @@ -0,0 +1,838 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import collections +import copyreg +from collections.abc import Sequence +from contextlib import contextmanager +from copy import deepcopy + +import torch +from torch import Tensor +from torch.__future__ import get_swap_module_params_on_conversion +from torch.nn.modules.container import Module, ModuleDict, ModuleList +from torch.nn.parameter import Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +__all__ = [ + "cached", + "ParametrizationList", + "register_parametrization", + "is_parametrized", + "remove_parametrizations", + "type_before_parametrizations", + "transfer_parametrizations_and_params", +] + +_cache_enabled = 0 +_cache: dict[tuple[int, str], Tensor | None] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. + + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + import torch.nn.utils.parametrize as P + + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +def _register_parameter_or_buffer(module, name, X) -> None: + if isinstance(X, Parameter): + module.register_parameter(name, X) + else: + module.register_buffer(name, X) + + +def _maybe_set(dest: Tensor, src: Tensor) -> None: + should_swap = ( + get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest) + ) + if should_swap: + if isinstance(dest, Parameter) and not isinstance(src, Parameter): + src = Parameter(src, requires_grad=dest.requires_grad) + torch.utils.swap_tensors(dest, src) + else: + dest.set_(src) # type: ignore[call-overload] + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`. + + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. + + If the first registered parametrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... + + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + """ + + original: Tensor + unsafe: bool + + def __init__( + self, + modules: Sequence[Module], + original: Tensor | Parameter, + unsafe: bool = False, + ) -> None: + # We require this because we need to treat differently the first parametrization + # This should never throw, unless this class is used from the outside + if len(modules) == 0: + raise ValueError("ParametrizationList requires one or more modules.") + + super().__init__(modules) + self.unsafe = unsafe + + # In plain words: + # module.weight must keep its dtype and shape. + # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, + # this should be of the same dtype as the original tensor + # + # We check that the following invariants hold: + # X = module.weight + # Y = param.right_inverse(X) + # assert isinstance(Y, Tensor) or + # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) + # Z = param(Y) if isinstance(Y, Tensor) else param(*Y) + # # Consistency checks + # assert X.dtype == Z.dtype and X.shape == Z.shape + # # If it has one input, this allows to be able to use set_ to be able to + # # move data to/from the original tensor without changing its id (which is what the + # # optimizer uses to track parameters) + # if isinstance(Y, Tensor) + # assert X.dtype == Y.dtype + # Below we use original = X, new = Y + + original_shape = original.shape + original_dtype = original.dtype + + # Compute new + with torch.no_grad(): + new = original + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + try: + new = module.right_inverse(new) # type: ignore[operator] + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity + + if not isinstance(new, Tensor) and not isinstance(new, Sequence): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " + f"Got {type(new).__name__}" + ) + + # Set the number of original tensors + self.is_tensor = isinstance(new, Tensor) + self.ntensors = 1 if self.is_tensor else len(new) + + # Register the tensor(s) + if self.is_tensor: + # pyrefly: ignore [missing-attribute] + if original.dtype != new.dtype: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the dtype.\n" + f"original.dtype: {original.dtype}\n" + # pyrefly: ignore [missing-attribute] + f"right_inverse(original).dtype: {new.dtype}" + ) + + # pyrefly: ignore [missing-attribute] + if original.device != new.device: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the device.\n" + f"original.device: {original.device}\n" + # pyrefly: ignore [missing-attribute] + f"right_inverse(original).device: {new.device}" + ) + + # Set the original to original so that the user does not need to re-register the parameter + # manually in the optimiser + with torch.no_grad(): + # pyrefly: ignore [bad-argument-type] + _maybe_set(original, new) + _register_parameter_or_buffer(self, "original", original) + else: + for i, originali in enumerate(new): + if not isinstance(originali, Tensor): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors " + "(list, tuple...). " + f"Got element {i} of the sequence with type {type(originali).__name__}." + ) + + # If the original tensor was a Parameter that required grad, we expect the user to + # add the new parameters to the optimizer after registering the parametrization + # (this is documented) + if isinstance(original, Parameter): + originali = Parameter(originali, original.requires_grad) + originali.requires_grad_(original.requires_grad) + _register_parameter_or_buffer(self, f"original{i}", originali) + + if not self.unsafe: + # Consistency checks: + # Since f : A -> B, right_inverse : B -> A, Z and original should live in B + # Z = forward(right_inverse(original)) + Z = self() + if not isinstance(Z, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(Z).__name__}." + ) + if Z.dtype != original_dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized dtype: {original_dtype}\n" + f"parametrized dtype: {Z.dtype}" + ) + if Z.shape != original_shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized shape: {original_shape}\n" + f"parametrized shape: {Z.shape}" + ) + + def right_inverse(self, value: Tensor) -> None: + r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order. + + Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor + or in ``self.original0``, ``self.original1``, ... if it outputs several. + + Args: + value (Tensor): Value to which initialize the module + """ + # All the exceptions in this function should almost never throw. + # They could throw if, for example, right_inverse function returns a different + # dtype when given a different input, which should most likely be caused by a + # bug in the user's code + + with torch.no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + value = module.right_inverse(value) # type: ignore[operator] + else: + raise RuntimeError( + f"parametrization {type(module).__name__} does not implement " + "right_inverse." + ) + if self.is_tensor: + # These exceptions should only throw when a right_inverse function does not + # return the same dtype for every input, which should most likely be caused by a bug + if not isinstance(value, Tensor): + raise ValueError( + f"`right_inverse` should return a tensor. Got {type(value).__name__}" + ) + if value.dtype != self.original.dtype: + raise ValueError( + f"The tensor returned by `right_inverse` has dtype {value.dtype} " + f"while `original` has dtype {self.original.dtype}" + ) + # We know that the result is going to have the same dtype + _maybe_set(self.original, value) + else: + if not isinstance(value, collections.abc.Sequence): + raise ValueError( + "'right_inverse' must return a sequence of tensors. " + f"Got {type(value).__name__}." + ) + if len(value) != self.ntensors: + raise ValueError( + "'right_inverse' must return a sequence of tensors of length " + f"{self.ntensors}. Got a sequence of length {len(value)}." + ) + for i, tensor in enumerate(value): + original_i = getattr(self, f"original{i}") + if not isinstance(tensor, Tensor): + raise ValueError( + f"`right_inverse` must return a sequence of tensors. " + f"Got element {i} of type {type(tensor).__name__}" + ) + if original_i.dtype != tensor.dtype: + raise ValueError( + f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " + f"while `original{i}` has dtype {original_i.dtype}" + ) + _maybe_set(original_i, tensor) + + def forward(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + # Unpack the originals for the first parametrization + if self.is_tensor: + x = self[0](self.original) + else: + originals = (getattr(self, f"original{i}") for i in range(self.ntensors)) + x = self[0](*originals) + # It's not possible to call self[1:] here, so we have to be a bit more cryptic + # Also we want to skip all non-integer keys + curr_idx = 1 + while hasattr(self, str(curr_idx)): + x = self[curr_idx](x) + curr_idx += 1 + return x + + +def _inject_new_class(module: Module) -> None: + r"""Set up a module to be parametrized. + + This works by substituting the class of the module by a class + that extends it to be able to inject a property + + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def default_deepcopy(self, memo): + # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. + obj = memo.get(id(self), None) + if obj is not None: + return obj + replica = self.__new__(self.__class__) + memo[id(self)] = replica + replica.__dict__ = deepcopy(self.__dict__, memo) + # Also save all slots if they exist. + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(replica, slot, deepcopy(getattr(self, slot), memo)) + return replica + + def getstate(self): + raise RuntimeError( + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pytorch.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + ) + + dct = {"__getstate__": getstate} + # We don't allow serialization of parametrized modules but should still allow deepcopying. + # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. + if not hasattr(cls, "__deepcopy__"): + dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] + + param_cls = type( + f"Parametrized{cls.__name__}", + (cls,), + dct, + ) + + module.__class__ = param_cls + + +def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under :attr:`tensor_name` + has already been moved out + + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the name of the property to create + """ + # We check the precondition. + # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) + + @torch.jit.unused + def get_cached_parametrization(parametrization) -> Tensor: + global _cache + key = (id(module), tensor_name) + tensor = _cache.get(key) + if tensor is None: + tensor = parametrization() + _cache[key] = tensor + return tensor + + def get_parametrized(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + parametrization = self.parametrizations[tensor_name] + # pyrefly: ignore [redundant-condition] + if _cache_enabled: + if torch.jit.is_scripting(): + # Scripting + raise RuntimeError( + "Caching is not implemented for scripting. " + "Either disable caching or avoid scripting." + ) + elif torch._C._get_tracing_state() is not None: + # Tracing + raise RuntimeError( + "Cannot trace a model while caching parametrizations." + ) + else: + return get_cached_parametrization(parametrization) + else: + # If caching is not active, this function just evaluates the parametrization + return parametrization() + + def set_original(self, value: Tensor) -> None: + if torch.jit.is_scripting(): + raise RuntimeError("Parametrization is not working with scripting.") + self.parametrizations[tensor_name].right_inverse(value) + + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + + +def register_parametrization( + module: Module, + tensor_name: str, + parametrization: Module, + *, + unsafe: bool = False, +) -> Module: + r"""Register a parametrization to a tensor in a module. + + Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, + the module will return the parametrized version ``parametrization(module.weight)``. + If the original tensor requires a gradient, the backward pass will differentiate + through :attr:`parametrization`, and the optimizer will update the tensor accordingly. + + The first time that a module registers a parametrization, this function will add an attribute + ``parametrizations`` to the module of type :class:`~ParametrizationList`. + + The list of parametrizations on the tensor ``weight`` will be accessible under + ``module.parametrizations.weight``. + + The original tensor will be accessible under + ``module.parametrizations.weight.original``. + + Parametrizations may be concatenated by registering several parametrizations + on the same attribute. + + The training mode of a registered parametrization is updated on registration + to match the training mode of the host module + + Parametrized parameters and buffers have an inbuilt caching system that can be activated + using the context manager :func:`cached`. + + A :attr:`parametrization` may optionally implement a method with signature + + .. code-block:: python + + def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] + + This method is called on the unparametrized tensor when the first parametrization + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. + + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. + + It is possible for the first parametrization to depend on several inputs. + This may be implemented returning a tuple of tensors from ``right_inverse`` + (see the example implementation of a ``RankOne`` parametrization below). + + In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` + with names ``original0``, ``original1``,... + + .. note:: + + If unsafe=False (default) both the forward and right_inverse methods will be called + once to perform a number of consistency checks. + If unsafe=True, then right_inverse will be called if the tensor is not parametrized, + and nothing will be called otherwise. + + .. note:: + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. + + .. warning:: + + If a parametrization depends on several inputs, :func:`~register_parametrization` + will register a number of new parameters. If such parametrization is registered + after the optimizer is created, these new parameters will need to be added manually + to the optimizer. See :meth:`torch.Optimizer.add_param_group`. + + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + Keyword args: + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + + Raises: + ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> import torch + >>> import torch.nn as nn + >>> import torch.nn.utils.parametrize as P + >>> + >>> class Symmetric(nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = torch.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(torch.allclose(m.weight, A)) + True + + >>> class RankOne(nn.Module): + >>> def forward(self, x, y): + >>> # Form a rank 1 matrix multiplying two vectors + >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) + >>> + >>> def right_inverse(self, Z): + >>> # Project Z onto the rank 1 matrices + >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) + >>> # Return rescaled singular vectors + >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) + >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + >>> + >>> linear_rank_one = P.register_parametrization( + ... nn.Linear(4, 4), "weight", RankOne() + ... ) + >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) + 1 + + """ + parametrization.train(module.training) + if is_parametrized(module, tensor_name): + # Correctness checks. + # If A is the space of tensors with shape and dtype equal to module.weight + # we check that parametrization.forward and parametrization.right_inverse are + # functions from A to A + if not unsafe: + Y = getattr(module, tensor_name) + X = parametrization(Y) + if not isinstance(X, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(X).__name__}." + ) + if X.dtype != Y.dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"parametrization(module.{tensor_name}).dtype: {X.dtype}" + ) + if X.shape != Y.shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"parametrization(module.{tensor_name}).shape: {X.shape}" + ) + if hasattr(parametrization, "right_inverse"): + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) + # else right_inverse is assumed to be the identity + + # add the new parametrization to the parametrization list + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name].append(parametrization) # type: ignore[operator] + # If unsafe was True in previous parametrization, keep it enabled + module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr, operator] + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # We create this early to check for possible errors + parametrizations = ParametrizationList( + [parametrization], original, unsafe=unsafe + ) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization registered on the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name] = parametrizations + else: + raise ValueError( + f"Module '{module}' does not have a parameter, a buffer, or a " + f"parametrized element with name '{tensor_name}'" + ) + return module + + +def is_parametrized(module: Module, tensor_name: str | None = None) -> bool: + r"""Determine if a module has a parametrization. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +def remove_parametrizations( + module: Module, + tensor_name: str, + leave_parametrized: bool = True, +) -> Module: + r"""Remove the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + This is only possible when the parametrization depends on just one tensor. + + Args: + module (nn.Module): module from which remove the parametrization + tensor_name (str): name of the parametrization to be removed + leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. + Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors + """ + if not is_parametrized(module, tensor_name): + raise ValueError( + f"Module {module} does not have a parametrization on {tensor_name}" + ) + + # Fetch the original tensor + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + parametrizations = module.parametrizations[tensor_name] + # pyrefly: ignore [invalid-argument] + if parametrizations.is_tensor: + original = parametrizations.original + assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor" + if leave_parametrized: + with torch.no_grad(): + t = getattr(module, tensor_name) + # We know they have the same dtype because we have checked this when registering the + # parametrizations. As such, we can use set_ + # We do this so that the parameter does not to change the id() + # This way the user does not need to update the optimizer + with torch.no_grad(): + if type(original) is torch.Tensor: + _maybe_set(original, t) + else: + try: + _maybe_set(original, t) + except RuntimeError as e: + # TODO: Fix this for tensor subclasses that are parameters: + # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach(). + raise RuntimeError( + "Calling remove_parametrizations() with leave_parametrized=True " + "for a parameter that is an instance of a tensor subclass requires " + "set_() to be implemented correctly for the tensor subclass." + "Alternatively, one can opt into the swap_tensors path" + "Either set leave_parametrized=False or provide a working implementation" + "for set_() in the tensor subclass or set " + "torch.__future__.set_swap_module_params_on_conversion(True)." + ) from e + else: + if leave_parametrized: + # We cannot use no_grad because we need to know whether one or more + # original tensors required grad + t = getattr(module, tensor_name) + # We'll have to trust the user to add it to the optimizer + original = Parameter(t) if t.requires_grad else t + else: + raise ValueError( + "Cannot leave unparametrized (`leave_parametrized=False`) a tensor " + "that is parametrized in terms of a sequence of tensors." + ) + + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] + + # Restore the parameter / buffer into the main class + _register_parameter_or_buffer(module, tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module + + +def type_before_parametrizations(module: Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. + + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) + + +def transfer_parametrizations_and_params( + from_module: Module, + to_module: Module, + tensor_name: str | None = None, +) -> Module: + r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`. + + If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise + transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them. + Does nothing if from_module is not parametrized. + + Args: + from_module (nn.Module): module to transfer from + to_module (nn.Module): module to transfer to + tensor_name (str, optional): parameter to transfer + + Returns: + Module: to_module + """ + if is_parametrized(from_module): + assert isinstance(from_module.parametrizations, ModuleDict) # for mypy + + # get list of all params or the single param to transfer + parameters_to_transfer: list | ModuleDict = ( + from_module.parametrizations if tensor_name is None else [tensor_name] + ) + + assert hasattr(parameters_to_transfer, "__iter__") # for mypy + for parameter_name in parameters_to_transfer: + # initialize the to-be-transferred param in to_module if it doesn't exist already + if not hasattr(to_module, parameter_name): + setattr( + to_module, + parameter_name, + Parameter(getattr(from_module, parameter_name)), + ) + + # apply the params's parametrizations to to_module + for param_func in from_module.parametrizations[ # type: ignore[attr-defined] + parameter_name + ]: + register_parametrization(to_module, parameter_name, param_func) + assert isinstance(to_module.parametrizations, ModuleDict) # for mypy + + # make values match, original values can be stored in either original or + # original0, original1..., need to check both cases + if hasattr(from_module.parametrizations[parameter_name], "original"): + to_module.parametrizations[ + parameter_name + ].original = from_module.parametrizations[parameter_name].original + else: + num = 0 + orig_num = "original" + str(num) + # loop through each original# until all values have been set + while hasattr(from_module.parametrizations[parameter_name], orig_num): + setattr( + to_module.parametrizations[parameter_name], + orig_num, + getattr(from_module.parametrizations[parameter_name], orig_num), + ) + num = num + 1 + orig_num = "original" + str(num) + + return to_module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/prune.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/prune.py new file mode 100644 index 0000000000000000000000000000000000000000..827bf19ed4bea00723e38d2ca60dcf14cc3abbc2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/prune.py @@ -0,0 +1,1385 @@ +# mypy: allow-untyped-defs +r"""Pruning methods.""" + +import numbers +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import torch + + +class BasePruningMethod(ABC): + r"""Abstract base class for creation of new pruning techniques. + + Provides a skeleton for customization requiring the overriding of methods + such as :meth:`compute_mask` and :meth:`apply`. + """ + + _tensor_name: str + + def __call__(self, module, inputs): + r"""Multiply the mask into original tensor and store the result. + + Multiplies the mask (stored in ``module[name + '_mask']``) + into the original tensor (stored in ``module[name + '_orig']``) + and stores the result into ``module[name]`` by using :meth:`apply_mask`. + + Args: + module (nn.Module): module containing the tensor to prune + inputs: not used. + """ + setattr(module, self._tensor_name, self.apply_mask(module)) + + @abstractmethod + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a random mask to + apply on top of the ``default_mask`` according to the specific pruning + method recipe. + + Args: + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + """ + + def apply_mask(self, module): + r"""Simply handles the multiplication between the parameter being pruned and the generated mask. + + Fetches the mask and the original tensor from the module + and returns the pruned version of the tensor. + + Args: + module (nn.Module): module containing the tensor to prune + + Returns: + pruned_tensor (torch.Tensor): pruned version of the input tensor + """ + # to carry out the multiplication, the mask needs to have been computed, + # so the pruning method must know what tensor it's operating on + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned" + ) # this gets set in apply() + mask = getattr(module, self._tensor_name + "_mask") + orig = getattr(module, self._tensor_name + "_orig") + pruned_tensor = mask.to(dtype=orig.dtype) * orig + return pruned_tensor + + @classmethod + def apply(cls, module, name, *args, importance_scores=None, **kwargs): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + args: arguments passed on to a subclass of + :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. + kwargs: keyword arguments passed on to a subclass of a + :class:`BasePruningMethod` + """ + + def _get_composite_method(cls, module, name, *args, **kwargs): + # Check if a pruning method has already been applied to + # `module[name]`. If so, store that in `old_method`. + old_method = None + found = 0 + # there should technically be only 1 hook with hook.name == name + # assert this using `found` + hooks_to_remove = [] + for k, hook in module._forward_pre_hooks.items(): + # if it exists, take existing thing, remove hook, then + # go through normal thing + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + old_method = hook + hooks_to_remove.append(k) + found += 1 + assert found <= 1, ( + f"Avoid adding multiple pruning hooks to the\ + same tensor {name} of module {module}. Use a PruningContainer." + ) + + for k in hooks_to_remove: + del module._forward_pre_hooks[k] + + # Apply the new pruning method, either from scratch or on top of + # the previous one. + method = cls(*args, **kwargs) # new pruning + # Have the pruning method remember what tensor it's been applied to + method._tensor_name = name + + # combine `methods` with `old_method`, if `old_method` exists + if old_method is not None: # meaning that there was a hook + # if the hook is already a pruning container, just add the + # new pruning method to the container + if isinstance(old_method, PruningContainer): + old_method.add_pruning_method(method) + method = old_method # rename old_method --> method + + # if the hook is simply a single pruning method, create a + # container, add the old pruning method and the new one + elif isinstance(old_method, BasePruningMethod): + container = PruningContainer(old_method) + # Have the pruning method remember the name of its tensor + # setattr(container, '_tensor_name', name) + container.add_pruning_method(method) + method = container # rename container --> method + return method + + method = _get_composite_method(cls, module, name, *args, **kwargs) + # at this point we have no forward_pre_hooks but we could have an + # active reparameterization of the tensor if another pruning method + # had been applied (in which case `method` would be a PruningContainer + # and not a simple pruning method). + + # Pruning is to be applied to the module's tensor named `name`, + # starting from the state it is found in prior to this iteration of + # pruning. The pruning mask is calculated based on importances scores. + + orig = getattr(module, name) + if importance_scores is not None: + assert importance_scores.shape == orig.shape, ( + f"importance_scores should have the same shape as parameter {name} of {module}" + ) + else: + importance_scores = orig + + # If this is the first time pruning is applied, take care of moving + # the original tensor to a new parameter called name + '_orig' and + # and deleting the original parameter + if not isinstance(method, PruningContainer): + # copy `module[name]` to `module[name + '_orig']` + module.register_parameter(name + "_orig", orig) + # temporarily delete `module[name]` + del module._parameters[name] + default_mask = torch.ones_like(orig) # temp + # If this is not the first time pruning is applied, all of the above + # has been done before in a previous pruning iteration, so we're good + # to go + else: + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) + + # Use try/except because if anything goes wrong with the mask + # computation etc., you'd want to roll back. + try: + # get the final mask, computed according to the specific method + mask = method.compute_mask(importance_scores, default_mask=default_mask) + # reparameterize by saving mask to `module[name + '_mask']`... + module.register_buffer(name + "_mask", mask) + # ... and the new pruned tensor to `module[name]` + setattr(module, name, method.apply_mask(module)) + # associate the pruning method to the module via a hook to + # compute the function before every forward() (compile by run) + module.register_forward_pre_hook(method) + + except Exception as e: + if not isinstance(method, PruningContainer): + orig = getattr(module, name + "_orig") + module.register_parameter(name, orig) + del module._parameters[name + "_orig"] + raise e + + return method + + def prune(self, t, default_mask=None, importance_scores=None): + r"""Compute and returns a pruned version of input tensor ``t``. + + According to the pruning rule specified in :meth:`compute_mask`. + + Args: + t (torch.Tensor): tensor to prune (of same dimensions as + ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. + default_mask (torch.Tensor, optional): mask from previous pruning + iteration, if any. To be considered when determining what + portion of the tensor that pruning should act on. If None, + default to a mask of ones. + + Returns: + pruned version of tensor ``t``. + """ + if importance_scores is not None: + assert importance_scores.shape == t.shape, ( + "importance_scores should have the same shape as tensor t" + ) + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) + + def remove(self, module) -> None: + r"""Remove the pruning reparameterization from a module. + + The pruned parameter named ``name`` remains permanently pruned, + and the parameter named ``name+'_orig'`` is removed from the parameter list. + Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! + """ + # before removing pruning from a tensor, it has to have been applied + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned before pruning can be removed" + ) # this gets set in apply() + + # to update module[name] to latest trained weights + weight = self.apply_mask(module) # masked weights + + # delete and reset + if hasattr(module, self._tensor_name): + delattr(module, self._tensor_name) + orig = module._parameters[self._tensor_name + "_orig"] + orig.data = weight.data + del module._parameters[self._tensor_name + "_orig"] + del module._buffers[self._tensor_name + "_mask"] + setattr(module, self._tensor_name, orig) + + +class PruningContainer(BasePruningMethod): + """Container holding a sequence of pruning methods for iterative pruning. + + Keeps track of the order in which pruning methods are applied and handles + combining successive pruning calls. + + Accepts as argument an instance of a BasePruningMethod or an iterable of + them. + """ + + def __init__(self, *args) -> None: + self._pruning_methods: tuple[BasePruningMethod, ...] = () + if not isinstance(args, Iterable): # only 1 item + self._tensor_name = args._tensor_name + self.add_pruning_method(args) + # pyrefly: ignore [bad-argument-type] + elif len(args) == 1: # only 1 item in a tuple + # pyrefly: ignore [index-error] + self._tensor_name = args[0]._tensor_name + # pyrefly: ignore [index-error] + self.add_pruning_method(args[0]) + else: # manual construction from list or other iterable (or no args) + for method in args: + self.add_pruning_method(method) + + def add_pruning_method(self, method) -> None: + r"""Add a child pruning ``method`` to the container. + + Args: + method (subclass of BasePruningMethod): child pruning method + to be added to the container. + """ + # check that we're adding a pruning method to the container + if not isinstance(method, BasePruningMethod) and method is not None: + raise TypeError(f"{type(method)} is not a BasePruningMethod subclass") + elif method is not None and self._tensor_name != method._tensor_name: + raise ValueError( + "Can only add pruning methods acting on " + f"the parameter named '{self._tensor_name}' to PruningContainer {self}." + + f" Found '{method._tensor_name}'" + ) + # if all checks passed, add to _pruning_methods tuple + self._pruning_methods += (method,) # type: ignore[operator] + + def __len__(self) -> int: + return len(self._pruning_methods) + + def __iter__(self): + return iter(self._pruning_methods) + + def __getitem__(self, idx): + return self._pruning_methods[idx] + + def compute_mask(self, t, default_mask): + r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. + + The new partial mask should be computed on the entries or channels + that were not zeroed out by the ``default_mask``. + Which portions of the tensor ``t`` the new mask will be calculated from + depends on the ``PRUNING_TYPE`` (handled by the type handler): + + * for 'unstructured', the mask will be computed from the raveled + list of nonmasked entries; + + * for 'structured', the mask will be computed from the nonmasked + channels in the tensor; + + * for 'global', the mask will be computed across all entries. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as ``default_mask``). + default_mask (torch.Tensor): mask from previous pruning iteration. + + Returns: + mask (torch.Tensor): new mask that combines the effects + of the ``default_mask`` and the new mask from the current + pruning ``method`` (of same dimensions as ``default_mask`` and + ``t``). + """ + + def _combine_masks(method, t, mask): + r"""Combine the masks from all pruning methods and returns a new mask. + + Args: + method (a BasePruningMethod subclass): pruning method + currently being applied. + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as mask). + mask (torch.Tensor): mask from previous pruning iteration + + Returns: + new_mask (torch.Tensor): new mask that combines the effects + of the old mask and the new mask from the current + pruning method (of same dimensions as mask and t). + """ + new_mask = mask # start off from existing mask + new_mask = new_mask.to(dtype=t.dtype) + + # compute a slice of t onto which the new pruning method will operate + if method.PRUNING_TYPE == "unstructured": + # prune entries of t where the mask is 1 + slc = mask == 1 + + # for struct pruning, exclude channels that have already been + # entirely pruned + elif method.PRUNING_TYPE == "structured": + if not hasattr(method, "dim"): + raise AttributeError( + "Pruning methods of PRUNING_TYPE " + '"structured" need to have the attribute `dim` defined.' + ) + + # find the channels to keep by removing the ones that have been + # zeroed out already (i.e. where sum(entries) == 0) + n_dims = t.dim() # "is this a 2D tensor? 3D? ..." + dim = method.dim + # convert negative indexing + if dim < 0: + dim = n_dims + dim + # if dim is still negative after subtracting it from n_dims + if dim < 0: + raise IndexError( + f"Index is out of bounds for tensor with dimensions {n_dims}" + ) + # find channels along dim = dim that aren't already tots 0ed out + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 + # create slice to identify what to prune + slc = [slice(None)] * n_dims + slc[dim] = keep_channel + + elif method.PRUNING_TYPE == "global": + n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." + slc = [slice(None)] * n_dims + + else: + raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") + + # compute the new mask on the unpruned slice of the tensor t + if isinstance(slc, list): + slc = tuple(slc) + partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) + new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) + + return new_mask + + method = self._pruning_methods[-1] + mask = _combine_masks(method, t, default_mask) + return mask + + +class Identity(BasePruningMethod): + r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.""" + + PRUNING_TYPE = "unstructured" + + def compute_mask(self, t, default_mask): + mask = default_mask + return mask + + @classmethod + def apply(cls, module, name): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name) + + +class RandomUnstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor at random. + + Args: + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount) -> None: + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + prob = torch.rand_like(t) + topk = torch.topk(prob.view(-1), k=nparams_toprune) + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + return super().apply(module, name, amount=amount) + + +class L1Unstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount) -> None: + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + # largest=True --> top k; largest=False --> bottom k + # Prune the smallest k + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) + # topk will have .indices and .values + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount, importance_scores=None): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + """ + return super().apply( + module, name, amount=amount, importance_scores=importance_scores + ) + + +class RandomStructured(BasePruningMethod): + r"""Prune entire (currently unpruned) channels in a tensor at random. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + + PRUNING_TYPE = "structured" + + def __init__(self, amount, dim=-1) -> None: + # Check range of validity of amount + _validate_pruning_amount_init(amount) + self.amount = amount + self.dim = dim + + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a random mask to + apply on top of the ``default_mask`` by randomly zeroing out channels + along the specified dim of the tensor. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + + Raises: + IndexError: if ``self.dim >= len(t.shape)`` + """ + # Check that tensor has structure (i.e. more than 1 dimension) such + # that the concept of "channels" makes sense + _validate_structured_pruning(t) + + # Check that self.dim is a valid dim to index t, else raise IndexError + _validate_pruning_dim(t, self.dim) + + # Check that the amount of channels to prune is not > than the number of + # channels in t along the dim to prune + tensor_size = t.shape[self.dim] + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + # Compute binary mask by initializing it to all 0s and then filling in + # 1s wherever topk.indices indicates, along self.dim. + # mask has the same shape as tensor t + def make_mask(t, dim, nchannels, nchannels_toprune): + # generate a random number in [0, 1] to associate to each channel + prob = torch.rand(nchannels) + # generate mask for each channel by 0ing out the channels that + # got assigned the k = nchannels_toprune lowest values in prob + threshold = torch.kthvalue(prob, k=nchannels_toprune).values + channel_mask = prob > threshold + + mask = torch.zeros_like(t) + slc = [slice(None)] * len(t.shape) + slc[dim] = channel_mask + slc = tuple(slc) + mask[slc] = 1 + return mask + + if nparams_toprune == 0: # k=0 not supported by torch.kthvalue + mask = default_mask + else: + # apply the new structured mask on top of prior (potentially + # unstructured) mask + mask = make_mask(t, self.dim, tensor_size, nparams_toprune) + mask *= default_mask.to(dtype=mask.dtype) + return mask + + @classmethod + def apply(cls, module, name, amount, dim=-1): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + return super().apply(module, name, amount=amount, dim=dim) + + +class LnStructured(BasePruningMethod): + r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm. + + Args: + amount (int or float): quantity of channels to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int, optional): index of the dim along which we define + channels to prune. Default: -1. + """ + + PRUNING_TYPE = "structured" + + def __init__(self, amount, n, dim=-1) -> None: + # Check range of validity of amount + _validate_pruning_amount_init(amount) + self.amount = amount + self.n = n + self.dim = dim + + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a mask to apply on + top of the ``default_mask`` by zeroing out the channels along the + specified dim with the lowest L\ ``n``-norm. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + + Raises: + IndexError: if ``self.dim >= len(t.shape)`` + """ + # Check that tensor has structure (i.e. more than 1 dimension) such + # that the concept of "channels" makes sense + _validate_structured_pruning(t) + # Check that self.dim is a valid dim to index t, else raise IndexError + _validate_pruning_dim(t, self.dim) + + # Check that the amount of channels to prune is not > than the number of + # channels in t along the dim to prune + tensor_size = t.shape[self.dim] + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + nparams_tokeep = tensor_size - nparams_toprune + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + # Structured pruning prunes entire channels so we need to know the + # L_n norm along each channel to then find the topk based on this + # metric + norm = _compute_norm(t, self.n, self.dim) + # largest=True --> top k; largest=False --> bottom k + # Keep the largest k channels along dim=self.dim + topk = torch.topk(norm, k=nparams_tokeep, largest=True) + # topk will have .indices and .values + + # Compute binary mask by initializing it to all 0s and then filling in + # 1s wherever topk.indices indicates, along self.dim. + # mask has the same shape as tensor t + def make_mask(t, dim, indices): + # init mask to 0 + mask = torch.zeros_like(t) + # e.g.: slc = [None, None, None], if len(t.shape) = 3 + slc = [slice(None)] * len(t.shape) + # replace a None at position=dim with indices + # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] + slc[dim] = indices + slc = tuple(slc) + # use slc to slice mask and replace all its entries with 1s + # e.g.: mask[:, :, [0, 2, 3]] = 1 + mask[slc] = 1 + return mask + + if nparams_toprune == 0: # k=0 not supported by torch.kthvalue + mask = default_mask + else: + mask = make_mask(t, self.dim, topk.indices) + mask *= default_mask.to(dtype=mask.dtype) + + return mask + + @classmethod + def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int): index of the dim along which we define channels to + prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + """ + return super().apply( + module, + name, + amount=amount, + n=n, + dim=dim, + importance_scores=importance_scores, + ) + + +class CustomFromMask(BasePruningMethod): + PRUNING_TYPE = "global" + + def __init__(self, mask) -> None: + self.mask = mask + + def compute_mask(self, t, default_mask): + assert default_mask.shape == self.mask.shape + mask = default_mask * self.mask.to(dtype=default_mask.dtype) + return mask + + @classmethod + def apply(cls, module, name, mask): # type: ignore[override] + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name, mask=mask) + + +def identity(module, name): + r"""Apply pruning reparametrization without pruning any units. + + Applies pruning reparametrization to the tensor corresponding to the + parameter called ``name`` in ``module`` without actually pruning any + units. Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Note: + The mask is a tensor of ones. + + Args: + module (nn.Module): module containing the tensor to prune. + name (str): parameter name within ``module`` on which pruning + will act. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.identity(nn.Linear(2, 3), "bias") + >>> print(m.bias_mask) + tensor([1., 1., 1.]) + """ + Identity.apply(module, name) + return module + + +def random_unstructured(module, name, amount): + r"""Prune tensor by removing random (currently unpruned) units. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) units + selected at random. + Modifies module in place (and also return the modified module) by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1) + >>> torch.sum(m.weight_mask == 0) + tensor(1) + + """ + RandomUnstructured.apply(module, name, amount) + return module + + +def l1_unstructured(module, name, amount, importance_scores=None): + r"""Prune tensor by removing units with the lowest L1-norm. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified `amount` of (currently unpruned) units with the + lowest L1-norm. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2) + >>> m.state_dict().keys() + odict_keys(['bias', 'weight_orig', 'weight_mask']) + """ + L1Unstructured.apply( + module, name, amount=amount, importance_scores=importance_scores + ) + return module + + +def random_structured(module, name, amount, dim): + r"""Prune tensor by removing random channels along the specified dimension. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) channels + along the specified ``dim`` selected at random. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + dim (int): index of the dim along which we define channels to prune. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> # xdoctest: +SKIP + >>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1) + >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0)) + >>> print(columns_pruned) + 3 + """ + RandomStructured.apply(module, name, amount, dim) + return module + + +def ln_structured(module, name, amount, n, dim, importance_scores=None): + r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension. + + Prunes tensor corresponding to parameter called ``name`` in ``module`` + by removing the specified ``amount`` of (currently unpruned) channels + along the specified ``dim`` with the lowest L\ ``n``-norm. + Modifies module in place (and also return the modified module) + by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.ln_structured( + ... nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf") + ... ) + """ + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) + return module + + +def global_unstructured( + parameters, pruning_method, importance_scores=None, **kwargs +) -> None: + r""" + Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. + + Modifies modules in place by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + parameters (Iterable of (module, name) tuples): parameters of + the model to prune in a global fashion, i.e. by aggregating all + weights prior to deciding which ones to prune. module must be of + type :class:`nn.Module`, and name must be a string. + pruning_method (function): a valid pruning function from this module, + or a custom one implemented by the user that satisfies the + implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. + kwargs: other keyword arguments such as: + amount (int or float): quantity of parameters to prune across the + specified parameters. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Raises: + TypeError: if ``PRUNING_TYPE != 'unstructured'`` + + Note: + Since global structured pruning doesn't make much sense unless the + norm is normalized by the size of the parameter, we now limit the + scope of global pruning to unstructured methods. + + Examples: + >>> from torch.nn.utils import prune + >>> from collections import OrderedDict + >>> net = nn.Sequential( + ... OrderedDict( + ... [ + ... ("first", nn.Linear(10, 4)), + ... ("second", nn.Linear(4, 1)), + ... ] + ... ) + ... ) + >>> parameters_to_prune = ( + ... (net.first, "weight"), + ... (net.second, "weight"), + ... ) + >>> prune.global_unstructured( + ... parameters_to_prune, + ... pruning_method=prune.L1Unstructured, + ... amount=10, + ... ) + >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) + tensor(10) + + """ + # ensure parameters is a list or generator of tuples + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") + + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + # pyrefly: ignore [bad-argument-type] + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) + # similarly, flatten the masks (if they exist), or use a flattened vector + # of 1s of the same dimensions as t + default_mask = torch.nn.utils.parameters_to_vector( + [ + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) + for (module, name) in parameters + ] + ) + + # use the canonical pruning methods to compute the new mask, even if the + # parameter is now a flattened out version of `parameters` + container = PruningContainer() + container._tensor_name = "temp" # to make it match that of `method` + method = pruning_method(**kwargs) + method._tensor_name = "temp" # to make it match that of `container` + if method.PRUNING_TYPE != "unstructured": + raise TypeError( + 'Only "unstructured" PRUNING_TYPE supported for ' + f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}" + ) + + container.add_pruning_method(method) + + # use the `compute_mask` method from `PruningContainer` to combine the + # mask computed by the new method with the pre-existing mask + final_mask = container.compute_mask(relevant_importance_scores, default_mask) + + # Pointer for slicing the mask to match the shape of each parameter + pointer = 0 + for module, name in parameters: + param = getattr(module, name) + # The length of the parameter + num_param = param.numel() + # Slice the mask, reshape it + param_mask = final_mask[pointer : pointer + num_param].view_as(param) + # Assign the correct pre-computed mask to each parameter and add it + # to the forward_pre_hooks like any other pruning method + custom_from_mask(module, name, mask=param_mask) + + # Increment the pointer to continue slicing the final_mask + pointer += num_param + + +def custom_from_mask(module, name, mask): + r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``. + + Modifies module in place (and also return the modified module) by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + mask (Tensor): binary mask to be applied to the parameter. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.custom_from_mask( + ... nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0]) + ... ) + >>> print(m.bias_mask) + tensor([0., 1., 0.]) + + """ + CustomFromMask.apply(module, name, mask) + return module + + +def remove(module, name): + r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook. + + The pruned parameter named ``name`` remains permanently pruned, and the parameter + named ``name+'_orig'`` is removed from the parameter list. Similarly, + the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + + Examples: + >>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2) + >>> m = remove(m, name="weight") + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError( + f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed" + ) + + +def is_pruned(module) -> bool: + r"""Check if a module is pruned by looking for pruning pre-hooks. + + Check whether ``module`` is pruned by looking for + ``forward_pre_hooks`` in its modules that inherit from the + :class:`BasePruningMethod`. + + Args: + module (nn.Module): object that is either pruned or unpruned + + Returns: + binary answer to whether ``module`` is pruned. + + Examples: + >>> from torch.nn.utils import prune + >>> m = nn.Linear(5, 7) + >>> print(prune.is_pruned(m)) + False + >>> prune.random_unstructured(m, name="weight", amount=0.2) + >>> print(prune.is_pruned(m)) + True + """ + for _, submodule in module.named_modules(): + for hook in submodule._forward_pre_hooks.values(): + if isinstance(hook, BasePruningMethod): + return True + return False + + +def _validate_pruning_amount_init(amount) -> None: + r"""Validate helper to check the range of amount at init. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + + Raises: + ValueError: if amount is a float not in [0, 1], or if it's a negative + integer. + TypeError: if amount is neither a float nor an integer. + + Note: + This does not take into account the number of parameters in the + tensor to be pruned, which is known only at prune. + """ + if not isinstance(amount, numbers.Real): + raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.") + + if (isinstance(amount, numbers.Integral) and amount < 0) or ( + not isinstance(amount, numbers.Integral) # so it's a float + and (float(amount) > 1.0 or float(amount) < 0.0) + ): + raise ValueError( + f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer" + ) + + +def _validate_pruning_amount(amount, tensor_size) -> None: + r"""Validate that the pruning amount is meaningful wrt to the size of the data. + + Validation helper to check that the amount of parameters to prune + is meaningful wrt to the size of the data (`tensor_size`). + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + """ + # TODO: consider removing this check and allowing users to specify + # a number of units to prune that is greater than the number of units + # left to prune. In this case, the tensor will just be fully pruned. + + if isinstance(amount, numbers.Integral) and amount > tensor_size: + raise ValueError( + f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}" + ) + + +def _validate_structured_pruning(t) -> None: + r"""Validate that the tensor to be pruned is at least 2-Dimensional. + + Validation helper to check that the tensor to be pruned is multi- + dimensional, such that the concept of "channels" is well-defined. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + + Raises: + ValueError: if the tensor `t` is not at least 2D. + """ + shape = t.shape + if len(shape) <= 1: + raise ValueError( + "Structured pruning can only be applied to " + "multidimensional tensors. Found tensor of shape " + f"{shape} with {len(shape)} dims" + ) + + +def _compute_nparams_toprune(amount, tensor_size): + r"""Convert the pruning amount from a percentage to absolute value. + + Since amount can be expressed either in absolute value or as a + percentage of the number of units/channels in a tensor, this utility + function converts the percentage to absolute value to standardize + the handling of pruning. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + + Returns: + int: the number of units to prune in the tensor + """ + # incorrect type already checked in _validate_pruning_amount_init + if isinstance(amount, numbers.Integral): + return amount + else: + return round(amount * tensor_size) + + +def _validate_pruning_dim(t, dim) -> None: + r"""Validate that the pruning dimension is within the bounds of the tensor dimension. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + dim (int): index of the dim along which we define channels to prune + """ + if dim >= t.dim(): + raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}") + + +def _compute_norm(t, n, dim): + r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension. + + The L_n-norm will be computed across all entries in tensor `t` along all dimension + except for the one identified by dim. + Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), + then norm will have Size [4], and each entry will represent the + `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument p in torch.norm + dim (int): dim identifying the channels to prune + + Returns: + norm (torch.Tensor): L_n norm computed across all dimensions except + for `dim`. By construction, `norm.shape = t.shape[-1]`. + """ + # dims = all axes, except for the one identified by `dim` + dims = list(range(t.dim())) + # convert negative indexing + if dim < 0: + dim = dims[dim] + dims.remove(dim) + + norm = torch.norm(t, p=n, dim=dims) + return norm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f0530d99f94e0a0aa5fc5821ebefd85513e44c9f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/rnn.py @@ -0,0 +1,606 @@ +import warnings +from collections.abc import Callable, Iterable +from typing import Any, NamedTuple, overload, TypeVar +from typing_extensions import Self + +import torch +from torch import _VF, Tensor + + +__all__ = [ + "PackedSequence", + "invert_permutation", + "pack_padded_sequence", + "pad_packed_sequence", + "pad_sequence", + "unpad_sequence", + "pack_sequence", + "unpack_sequence", +] + +_T = TypeVar("_T") +_R = TypeVar("_R") + + +class PackedSequence_(NamedTuple): + data: torch.Tensor + batch_sizes: torch.Tensor + sorted_indices: torch.Tensor | None + unsorted_indices: torch.Tensor | None + + +def bind(optional: _T | None, fn: Callable[[_T], _R]) -> _R | None: + if optional is None: + return None + return fn(optional) + + +class PackedSequence(PackedSequence_): + r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence. + + All RNN modules accept packed sequences as inputs. + + Note: + Instances of this class should never be created manually. They are meant + to be instantiated by functions like :func:`pack_padded_sequence`. + + Batch sizes represent the number elements at each sequence step in + the batch, not the varying sequence lengths passed to + :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x`` + the :class:`PackedSequence` would contain data ``axbc`` with + ``batch_sizes=[2,1,1]``. + + Attributes: + data (Tensor): Tensor containing packed sequence + batch_sizes (Tensor): Tensor of integers holding + information about the batch size at each sequence step + sorted_indices (Tensor, optional): Tensor of integers holding how this + :class:`PackedSequence` is constructed from sequences. + unsorted_indices (Tensor, optional): Tensor of integers holding how this + to recover the original sequences with correct order. + + .. note:: + :attr:`data` can be on arbitrary device and of arbitrary dtype. + :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64`` + tensors on the same device as :attr:`data`. + + However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor. + + This invariant is maintained throughout :class:`PackedSequence` class, + and all functions that construct a :class:`PackedSequence` in PyTorch + (i.e., they only pass in tensors conforming to this constraint). + """ + + def __new__( + cls, + data: Tensor, + batch_sizes: Tensor | None = None, + sorted_indices: Tensor | None = None, + unsorted_indices: Tensor | None = None, + ) -> Self: + return super().__new__( + cls, + *_packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices + ), + ) + + # NOTE [ device and dtype of a PackedSequence ] + # + # See the note above in doc string (starting with ":attr:`data` can be on + # arbitrary device..."). + def pin_memory(self) -> Self: + # Why not convert `batch_sizes`? + # See NOTE [ device and dtype of a PackedSequence ] + return type(self)( + self.data.pin_memory(), + self.batch_sizes, + bind(self.sorted_indices, lambda t: t.pin_memory()), + bind(self.unsorted_indices, lambda t: t.pin_memory()), + ) + + @overload + def to( + self, + dtype: torch.dtype, + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + @overload + def to( + self, + device: str | torch.device | int | None = ..., + dtype: torch.dtype | None = ..., + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + @overload + def to( + self, + other: Tensor, + non_blocking: bool = ..., + copy: bool = ..., + ) -> Self: ... + + def to(self, *args: Any, **kwargs: Any) -> Self: + r"""Perform dtype and/or device conversion on `self.data`. + + It has similar signature as :meth:`torch.Tensor.to`, except optional + arguments like `non_blocking` and `copy` should be passed as kwargs, + not args, or they will not apply to the index tensors. + + .. note:: + + If the ``self.data`` Tensor already has the correct :class:`torch.dtype` + and :class:`torch.device`, then ``self`` is returned. + Otherwise, returns a copy with the desired configuration. + """ + # Why not convert `batch_sizes`? + # See NOTE [ device and dtype of a PackedSequence ] + data = self.data.to(*args, **kwargs) + if data is self.data: + return self + else: + # Does not forward device or dtype arg/kwargs, device is set from data.device + kwargs = dict( + filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items()) + ) + sorted_indices = bind( + self.sorted_indices, lambda t: t.to(data.device, **kwargs) + ) + unsorted_indices = bind( + self.unsorted_indices, lambda t: t.to(data.device, **kwargs) + ) + return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices) + + def cuda(self, *args: Any, **kwargs: Any) -> Self: + # Tests to see if 'cuda' should be added to kwargs + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) + if ex.is_cuda: + return self.to(*args, **kwargs) + kwargs["device"] = "cuda" + return self.to(*args, **kwargs) + + def cpu(self, *args: Any, **kwargs: Any) -> Self: + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) + if ex.device.type == "cpu": + return self.to(*args, **kwargs) + kwargs["device"] = "cpu" + return self.to(*args, **kwargs) + + def double(self) -> Self: + return self.to(dtype=torch.double) + + def float(self) -> Self: + return self.to(dtype=torch.float) + + def half(self) -> Self: + return self.to(dtype=torch.half) + + def long(self) -> Self: + return self.to(dtype=torch.long) + + def int(self) -> Self: + return self.to(dtype=torch.int) + + def short(self) -> Self: + return self.to(dtype=torch.short) + + def char(self) -> Self: + return self.to(dtype=torch.int8) + + def byte(self) -> Self: + return self.to(dtype=torch.uint8) + + @property + def is_cuda(self) -> bool: + r"""Return true if `self.data` stored on a gpu.""" + return self.data.is_cuda + + def is_pinned(self) -> bool: + r"""Return true if `self.data` stored on in pinned memory.""" + return self.data.is_pinned() + + +# TorchScript doesn't support constructors on named tuples, so we use this helper +# method to construct PackedSequence +def _packed_sequence_init_args( + data: Tensor, + batch_sizes: Tensor | None = None, + sorted_indices: Tensor | None = None, + unsorted_indices: Tensor | None = None, +) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + # NB: if unsorted_indices is provided, it should be the inverse permutation + # to sorted_indices. Don't assert it here because the PackedSequence ctor + # should only be used internally. + + if unsorted_indices is None: + unsorted_indices = invert_permutation(sorted_indices) + + # support being called as `PackedSequence(data, batch_sizes, sorted_indices)` + if batch_sizes is not None: + # TODO: Re-enable this check (.type isn't supported in TorchScript) + if batch_sizes.device.type != "cpu": + raise ValueError( + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn.utils.rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence" + ) + return data, batch_sizes, sorted_indices, unsorted_indices + + # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)` + else: + assert isinstance(data, (list, tuple)) and len(data) == 2 + return data[0], data[1], sorted_indices, unsorted_indices + + +def _packed_sequence_init( + data: Tensor, + batch_sizes: Tensor | None = None, + sorted_indices: Tensor | None = None, + unsorted_indices: Tensor | None = None, +) -> PackedSequence: + data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices + ) + return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices) + + +def invert_permutation(permutation: Tensor | None) -> Tensor | None: + """Returns the inverse of ``permutation``. + + This is useful for converting between sorted and unsorted indices in + a :class:`~nn.utils.rnn.PackedSequence`. + + Args: + permutation (Tensor, optional): a 1-D tensor of indices to invert + """ + if permutation is None: + return None + output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format) + output.scatter_( + 0, permutation, torch.arange(0, permutation.numel(), device=permutation.device) + ) + return output + + +def pack_padded_sequence( + input: Tensor, + lengths: Tensor | list[int], + batch_first: bool = False, + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a Tensor containing padded sequences of variable length. + + :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length + of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions + (including 0). + + For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is + ``True``, the sequences should be sorted by length in a decreasing order, i.e. + ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest + one. `enforce_sorted = True` is only necessary for ONNX export. + + It is an inverse operation to :func:`pad_packed_sequence`, and hence :func:`pad_packed_sequence` + can be used to recover the underlying tensor packed in :class:`PackedSequence`. + + Note: + This function accepts any input that has at least two dimensions. You + can apply it to pack the labels, and use the output of the RNN with + them to compute the loss directly. A Tensor can be retrieved from + a :class:`PackedSequence` object by accessing its ``.data`` attribute. + + Args: + input (Tensor): padded batch of variable length sequences. + lengths (Tensor or list(int)): list of sequence lengths of each batch + element (must be on the CPU if provided as a tensor). + batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *`` + format, ``T x B x *`` otherwise. Default: ``False``. + enforce_sorted (bool, optional): if ``True``, the input is expected to + contain sequences sorted by length in a decreasing order. If + ``False``, the input will get sorted unconditionally. Default: ``True``. + + .. warning:: + The dim of ``input`` tensor will be truncated if its length larger than + correspond value in ``length``. + + Returns: + a :class:`PackedSequence` object + """ + if not isinstance(lengths, torch.Tensor): + if torch._C._get_tracing_state(): + warnings.warn( + "pack_padded_sequence has been called with a Python list of " + "sequence lengths. The tracer cannot track the data flow of Python " + "values, and it will treat them as constants, likely rendering " + "the trace incorrect for any other combination of lengths.", + stacklevel=2, + ) + lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu") + else: + lengths = lengths.to(dtype=torch.int64) + + if enforce_sorted: + sorted_indices = None + else: + lengths, sorted_indices = torch.sort(lengths, descending=True) + sorted_indices = sorted_indices.to(input.device) + batch_dim = 0 if batch_first else 1 + input = input.index_select(batch_dim, sorted_indices) + + data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first) + return _packed_sequence_init(data, batch_sizes, sorted_indices, None) + + +def pad_packed_sequence( + sequence: PackedSequence, + batch_first: bool = False, + padding_value: float = 0.0, + total_length: int | None = None, +) -> tuple[Tensor, Tensor]: + r"""Pad a packed batch of variable length sequences. + + It is an inverse operation to :func:`pack_padded_sequence`. + + The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) , where ``T`` is the length of the longest + sequence and ``B`` is the batch size. + + Example: + >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) + >>> lens = [2, 1, 3] + >>> packed = pack_padded_sequence( + ... seq, lens, batch_first=True, enforce_sorted=False + ... ) + >>> packed + PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), + sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) + >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) + >>> seq_unpacked + tensor([[1, 2, 0], + [3, 0, 0], + [4, 5, 6]]) + >>> lens_unpacked + tensor([2, 1, 3]) + + .. note:: + :attr:`total_length` is useful to implement the + ``pack sequence -> recurrent network -> unpack sequence`` pattern in a + :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. + See :ref:`this FAQ section ` for + details. + + Args: + sequence (PackedSequence): batch to pad + batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` + format, ``T x B x *`` otherwise. + padding_value (float, optional): values for padded elements. + total_length (int, optional): if not ``None``, the output will be padded to + have length :attr:`total_length`. This method will throw :class:`ValueError` + if :attr:`total_length` is less than the max sequence length in + :attr:`sequence`. + + Returns: + Tuple of Tensor containing the padded sequence, and a Tensor + containing the list of lengths of each sequence in the batch. + Batch elements will be re-ordered as they were ordered originally when + the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``. + """ + max_seq_length = sequence.batch_sizes.size(0) + if total_length is not None: + if total_length < max_seq_length: + raise ValueError( + "Expected total_length to be at least the length " + "of the longest sequence in input, but got " + f"total_length={total_length} and max sequence length being {max_seq_length}" + ) + max_seq_length = total_length + padded_output, lengths = _VF._pad_packed_sequence( + sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length + ) + unsorted_indices = sequence.unsorted_indices + if unsorted_indices is not None: + batch_dim = 0 if batch_first else 1 + return ( + padded_output.index_select(batch_dim, unsorted_indices), + lengths[unsorted_indices.cpu()], + ) + return padded_output, lengths + + +# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable. +def pad_sequence( + sequences: Tensor | list[Tensor], + batch_first: bool = False, + padding_value: float = 0.0, + padding_side: str = "right", +) -> Tensor: + r"""Pad a list of variable length Tensors with :attr:`padding_value`. + + ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them + to equal length. :attr:`sequences` can be list of sequences with size ``L x *``, + where `L` is length of the sequence and ``*`` is any number of dimensions + (including ``0``). If :attr:`batch_first` is ``False``, the output is of size + ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size + (the number of elements in :attr:`sequences`), ``T`` is the length of the longest + sequence. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + torch.Size([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` + format, ``T x B x *`` otherwise. + padding_value (float, optional): value for padded elements. Default: ``0``. + padding_side (str, optional): the side to pad the sequences on. + Default: ``'right'``. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise + """ + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # JIT doesn't support `Iterable` + if not isinstance(sequences, Iterable): + msg = ( + "pad_sequence: Expected iterable for input sequences, but got arg of type: " + f"{type(sequences)}" + ) + raise RuntimeError(msg) + + # In JIT context this leads to, + # RuntimeError: cannot statically infer the expected size of a list in this context + sequences = tuple(sequences) # type: ignore[assignment] + else: + # For JIT, we only support Union[Tensor, Tuple[Tensor]] + if isinstance(sequences, torch.Tensor): + sequences = sequences.unbind(0) # type: ignore[assignment] + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + return torch._C._nn.pad_sequence( + sequences, # type: ignore[arg-type] + batch_first, + padding_value, + padding_side, # type: ignore[arg-type] + ) + + +def unpad_sequence( + padded_sequences: Tensor, + lengths: Tensor, + batch_first: bool = False, +) -> list[Tensor]: + r"""Unpad padded Tensor into a list of variable length Tensors. + + ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> sequences = [a, b, c] + >>> padded_sequences = pad_sequence(sequences) + >>> lengths = torch.as_tensor([v.size(0) for v in sequences]) + >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths) + >>> torch.allclose(sequences[0], unpadded_sequences[0]) + True + >>> torch.allclose(sequences[1], unpadded_sequences[1]) + True + >>> torch.allclose(sequences[2], unpadded_sequences[2]) + True + + Args: + padded_sequences (Tensor): padded sequences. + lengths (Tensor): length of original (unpadded) sequences. + batch_first (bool, optional): whether batch dimension first or not. Default: ``False``. + + Returns: + a list of :class:`Tensor` objects + """ + unpadded_sequences = [] + + if not batch_first: + padded_sequences.transpose_(0, 1) + + max_length = padded_sequences.shape[1] + idx = torch.arange(max_length, device=lengths.device) + + for seq, length in zip(padded_sequences, lengths, strict=True): + mask = idx < length + unpacked_seq = seq[mask] + unpadded_sequences.append(unpacked_seq) + + return unpadded_sequences + + +def pack_sequence( + sequences: list[Tensor], + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a list of variable length Tensors. + + Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``. + + ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is + the length of a sequence and `*` is any number of trailing dimensions, + including ``0``. + + For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` + is ``True``, the sequences should be sorted in the order of decreasing length. + ``enforce_sorted = True`` is only necessary for ONNX export. + + Example: + >>> from torch.nn.utils.rnn import pack_sequence + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5]) + >>> c = torch.tensor([6]) + >>> pack_sequence([a, b, c]) + PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) + + Args: + sequences (list[Tensor]): A list of sequences of decreasing length. + enforce_sorted (bool, optional): if ``True``, checks that the input + contains sequences sorted by length in a decreasing order. If + ``False``, this condition is not checked. Default: ``True``. + + Returns: + a :class:`PackedSequence` object + """ + lengths = torch.as_tensor([v.size(0) for v in sequences]) + return pack_padded_sequence( + pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted + ) + + +def unpack_sequence(packed_sequences: PackedSequence) -> list[Tensor]: + r"""Unpack PackedSequence into a list of variable length Tensors. + + ``packed_sequences`` should be a PackedSequence object. + + Example: + >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5]) + >>> c = torch.tensor([6]) + >>> sequences = [a, b, c] + >>> print(sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + >>> packed_sequences = pack_sequence(sequences) + >>> print(packed_sequences) + PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) + >>> unpacked_sequences = unpack_sequence(packed_sequences) + >>> print(unpacked_sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + + Args: + packed_sequences (PackedSequence): A PackedSequence object. + + Returns: + a list of :class:`Tensor` objects + """ + padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True) + unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True) + return unpacked_sequences diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/spectral_norm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..a11613a51dac49d5a52d2c55f51734de37bd9e47 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/spectral_norm.py @@ -0,0 +1,368 @@ +# mypy: allow-untyped-defs +"""Spectral Normalization from https://arxiv.org/abs/1802.05957.""" + +from typing import Any, TypeVar + +import torch +import torch.nn.functional as F +from torch.nn.modules import Module + + +__all__ = [ + "SpectralNorm", + "SpectralNormLoadStateDictPreHook", + "SpectralNormStateDictHook", + "spectral_norm", + "remove_spectral_norm", +] + + +class SpectralNorm: + # Invariant before and after each forward call: + # u = F.normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version: int = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + name: str + dim: int + n_power_iterations: int + eps: float + + def __init__( + self, + name: str = "weight", + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim] + ) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor: + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + "_orig") + u = getattr(module, self.name + "_u") + v = getattr(module, self.name + "_v") + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = F.normalize( + torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v + ) + u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone(memory_format=torch.contiguous_format) + v = v.clone(memory_format=torch.contiguous_format) + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module: Module) -> None: + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + "_u") + delattr(module, self.name + "_v") + delattr(module, self.name + "_orig") + module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr( + module, + self.name, + self.compute_weight(module, do_power_iteration=module.training), + ) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = F.normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.linalg.multi_dot( + [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)] + ).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply( + module: Module, name: str, n_power_iterations: int, dim: int, eps: float + ) -> "SpectralNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two spectral_norm hooks on the same parameter {name}" + ) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + if weight is None: + raise ValueError( + f"`SpectralNorm` cannot be applied as parameter `{name}` is None" + ) + if isinstance(weight, torch.nn.parameter.UninitializedParameter): + raise ValueError( + "The module passed to `SpectralNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying spectral normalization" + ) + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = F.normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + fn = self.fn + version = local_metadata.get("spectral_norm", {}).get( + fn.name + ".version", None + ) + if version is None or version < 1: + weight_key = prefix + fn.name + if ( + version is None + and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v")) + and weight_key not in state_dict + ): + # Detect if it is the updated state dict and just missing metadata. + # This could happen if the users are crafting a state dict themselves, + # so we just pretend that this is the newest. + return + has_missing_keys = False + for suffix in ("_orig", "", "_u"): + key = weight_key + suffix + if key not in state_dict: + has_missing_keys = True + if strict: + missing_keys.append(key) + if has_missing_keys: + return + with torch.no_grad(): + weight_orig = state_dict[weight_key + "_orig"] + weight = state_dict.pop(weight_key) + sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[weight_key + "_u"] + v = fn._solve_v_and_rescale(weight_mat, u, sigma) + state_dict[weight_key + "_v"] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata) -> None: + if "spectral_norm" not in local_metadata: + local_metadata["spectral_norm"] = {} + key = self.fn.name + ".version" + if key in local_metadata["spectral_norm"]: + raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}") + local_metadata["spectral_norm"][key] = self.fn._version + + +T_module = TypeVar("T_module", bound=Module) + + +def spectral_norm( + module: T_module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: int | None = None, +) -> T_module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + .. note:: + This function has been reimplemented as + :func:`torch.nn.utils.parametrizations.spectral_norm` using the new + parametrization functionality in + :func:`torch.nn.utils.parametrize.register_parametrization`. Please use + the newer version. This function will be deprecated in a future version + of PyTorch. + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + # pyrefly: ignore [bad-return] + return module + + +def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the spectral normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + else: + raise ValueError(f"spectral_norm of '{name}' not found in {module}") + + for k, hook in module._state_dict_hooks.items(): + if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name: + del module._state_dict_hooks[k] + break + + for k, hook in module._load_state_dict_pre_hooks.items(): + if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name: + del module._load_state_dict_pre_hooks[k] + break + + return module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/stateless.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/stateless.py new file mode 100644 index 0000000000000000000000000000000000000000..70f0afdeb52923a029a1843e1f2cfc702ab7473b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/stateless.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Any +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + + +__all__ = ["functional_call"] + + +def _untie_named_tensors_map( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], +) -> dict[str, Tensor]: + """ + Unties all tied tensors in the module to parameters_and_buffers. + + This function returns a new untied_parameters_and_buffers dictionary and leave the original + untied_parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors + in the module to untied_parameters_and_buffers. The value of the new key is the user-given value + in the original parameters_and_buffers dictionary. + + If there are more than one user-given values for the same tied tensor, it will raise an error. + + For example, if the module has two tied weights self.foo and self.tied_foo and the user passes + {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the + user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the + user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error. + + Args: + module (torch.nn.Module): the module to determine which tensors are tied. + parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparamaterizing the module. + + Returns: + A new untied version of the parameters_and_buffers dictionary. + + Raises: + ValueError: if there are more than one user-given values for the same tied tensor. + """ + # A map of {name: tensor} for all tensors (including tied ones) in the module. + all_named_tensors: dict[str, Tensor] = {} + all_named_tensors.update(module.named_parameters(remove_duplicate=False)) + all_named_tensors.update(module.named_buffers(remove_duplicate=False)) + + # A map of {tensor: set(all_tied_names)} for all tensor names in the module. + tensor_to_tied_names_map: dict[Tensor, set[str]] = {} + for name, tensor in all_named_tensors.items(): + if tensor not in tensor_to_tied_names_map: + tensor_to_tied_names_map[tensor] = set() + tensor_to_tied_names_map[tensor].add(name) + + # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. + # If a name is not tied, it will not be in this map. + tied_names_map: dict[str, set[str]] = {} + for tied_names in tensor_to_tied_names_map.values(): + if len(tied_names) > 1: + for tied_name in tied_names: + tied_names_map[tied_name] = tied_names + + # Make sure the user didn't pass multiple values for the same tied tensor. + given_names = set(parameters_and_buffers.keys()) + # same as given_names.intersection(tied_names_map.keys()) but dynamo can't + # handle that + given_names_for_tied_tensors: set[str] = set() + for name in given_names: + if name in tied_names_map: + given_names_for_tied_tensors.add(name) + + for given_name in given_names_for_tied_tensors: + tied_names = tied_names_map[given_name] + if ( + # Detect if there are multiple keys present for the same tied tensor. + len(tied_names.intersection(given_names_for_tied_tensors)) > 1 + # Only raise an error if the user passed multiple values for the same tied tensor. + # If all given values are the same, don't raise. + and len({parameters_and_buffers[tied_name] for tied_name in tied_names}) + != 1 + ): + raise ValueError( + f"functional_call got multiple values for keys {sorted(tied_names)}, " + f"which are tied. Consider using tie_weights=False" + ) + + # Untie the given named tensor map + # Make a copy for not modifying the original dict + untied_parameters_and_buffers = parameters_and_buffers.copy() + for given_name in given_names_for_tied_tensors: + for tied_name in tied_names_map[given_name]: + untied_parameters_and_buffers[tied_name] = parameters_and_buffers[ + given_name + ] + return untied_parameters_and_buffers + + +@contextlib.contextmanager +def _reparametrize_module( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, +): + if tie_weights: + untied_parameters_and_buffers = _untie_named_tensors_map( + module, parameters_and_buffers + ) + else: + untied_parameters_and_buffers = parameters_and_buffers + + accessor = NamedMemberAccessor(module) + if strict: + missing_keys, unexpected_keys = accessor.check_keys( + untied_parameters_and_buffers + ) + error_msgs = [] + if len(unexpected_keys) > 0: + error_msgs.append( + f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + ) + if len(missing_keys) > 0: + error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in reparametrizing for {}:\n\t{}".format( + module._get_name(), "\n\t".join(error_msgs) + ) + ) + + orig_parameters_and_buffers: dict[str, Tensor] = {} + try: + orig_parameters_and_buffers, _ = accessor.swap_tensors_dict( + untied_parameters_and_buffers, allow_missing=True + ) + yield + finally: + if stack_weights: + # When stacking is enabled, we will restore the weights in LIFO order. + orig_parameters_and_buffers = dict( + reversed(orig_parameters_and_buffers.items()) + ) + new_parameters_and_buffers, _ = accessor.swap_tensors_dict( + orig_parameters_and_buffers, allow_missing=True + ) + # Sometimes the module is not completely stateless and has some in-place modifications on + # the _parameters and _buffers dictionaries. + # Write the changed parameters and buffers back to the original dict. + parameters_and_buffers.update( + { + k: new_parameters_and_buffers[k] + for k in parameters_and_buffers + if k in new_parameters_and_buffers + } + ) + + +@deprecated( + "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " + "and will be removed in a future version of PyTorch. " + "Please use `torch.func.functional_call` instead which is a drop-in replacement.", + category=FutureWarning, +) +def functional_call( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + args: Any | tuple | None = None, + kwargs: dict[str, Any] | None = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. + + .. warning:: + + This API is deprecated as of PyTorch 2.0 and will be removed in a future + version of PyTorch. Please use :func:`torch.func.functional_call` instead, + which is a drop-in replacement for this API. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameters_and_buffers` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the `parameters_and_buffers` input. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffers (dict of str and Tensor): the parameters that will be used in + the module call. + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparamaterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. + """ + return _functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +def _functional_call( + module: "torch.nn.Module", + parameters_and_buffers: dict[str, Tensor], + args: Any | tuple | None = None, + kwargs: dict[str, Any] | None = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + # TODO allow kwargs such as unsafe and others for parametrization + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or isinstance( + module, + ( + torch.jit.RecursiveScriptModule, + torch.jit.ScriptModule, + torch.jit.ScriptFunction, + ), + ) + ): + raise RuntimeError("The stateless API can't be used with Jitted modules") + if isinstance(module, torch.nn.DataParallel): + raise RuntimeError( + "The stateless API can't be used with nn.DataParallel module" + ) + if kwargs is None: + kwargs = {} + if args is None: + args = () + elif not isinstance(args, tuple): + args = (args,) + with _reparametrize_module( + module, parameters_and_buffers, tie_weights=tie_weights, strict=strict + ): + return module(*args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..7b336e8b8c08e59b2ee3d12ab481bacb4b6aa33d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs +r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" + +from typing import Any, TypeVar +from typing_extensions import deprecated + +from torch import _weight_norm, norm_except_dim +from torch.nn.modules import Module +from torch.nn.parameter import Parameter, UninitializedParameter + + +__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"] + + +class WeightNorm: + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + # TODO Make return type more specific + def compute_weight(self, module: Module) -> Any: + g = getattr(module, self.name + "_g") + v = getattr(module, self.name + "_v") + return _weight_norm(v, g, self.dim) + + @staticmethod + @deprecated( + "`torch.nn.utils.weight_norm` is deprecated " + "in favor of `torch.nn.utils.parametrizations.weight_norm`.", + category=FutureWarning, + ) + def apply(module, name: str, dim: int) -> "WeightNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two weight_norm hooks on the same parameter {name}" + ) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + if isinstance(weight, UninitializedParameter): + raise ValueError( + "The module passed to `WeightNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying weight normalization" + ) + # remove w from parameter list + del module._parameters[name] + + # add g and v as new parameters and express w as g/||v|| * v + module.register_parameter( + name + "_g", Parameter(norm_except_dim(weight, 2, dim).data) + ) + module.register_parameter(name + "_v", Parameter(weight.data)) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + module.register_forward_pre_hook(fn) + + return fn + + def remove(self, module: Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + del module._parameters[self.name + "_g"] + del module._parameters[self.name + "_v"] + setattr(module, self.name, Parameter(weight.data)) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, self.name, self.compute_weight(module)) + + +T_module = TypeVar("T_module", bound=Module) + + +def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module: + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + .. warning:: + + This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` + which uses the modern parametrization API. The new ``weight_norm`` is compatible + with ``state_dict`` generated from old ``weight_norm``. + + Migration guide: + + * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. If this is bothering you, please comment on + https://github.com/pytorch/pytorch/issues/102999 + + * To remove the weight normalization reparametrization, use + :func:`torch.nn.utils.parametrize.remove_parametrizations`. + + * The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. To restore the old behavior, use + :func:`torch.nn.utils.parametrize.cached` before invoking the module + in question. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + + +def remove_weight_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the weight normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError(f"weight_norm of '{name}' not found in {module}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5cea2cdd2dfde0cbf65d6bf1e923198196950bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ae858447200b181e76e2bdd3baafbe219859ce5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9dd12ce300112337421ebdf65c9af4514b1c6a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50e84fabb0897f22209215fdd7fb7ff265cfa37f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28d2fcfae32fe156b7c0f28dfadcc72ce62fa855 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30fcd219d21875f3867f6c90558f2e2da03661a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40a033e749c726d359a7fd80ccfd6f4be68610bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bf001aa3287be6b16e36bc591c451d84cbd245 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca5a1bcd740fa1c27e48d4fe65d9440e4ba57c04 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..313535bf314fcc9d67040058a2320485973797d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17ca5766442da4ac7a8dff8cef4655f2f32aa58c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e2ca0eb1aad85e27d907a0a934a8503e1aabc8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f29f981af3328fbfc171a8cff900b3a5dc11c1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49a1bf884d46779ec151738a5bcfb8c4c6922c2e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bee0c9d33f919dbd733095d4b87d0f351a9b9fd0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a67e8822c6e17d19d647f56e82ffcab511143af6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01cbd457374c27e40b07daca5ae1644a701767d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__init__.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.convert import convert +from torch.ao.quantization.fx.fuse import fuse + +# omitting files that's unlikely to be used right now, for example +# the newly added lower_to_fbgemm etc. +from torch.ao.quantization.fx.prepare import prepare diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0111d5fcd0bcedd695c48155a9c009439be17df4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fade105beda2b0768c371c8c00379e1dac6d1a2b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/_equalize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/convert.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/convert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97013f4a5698f2619e1d6cfaf406b1c80bab4f11 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/convert.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c503580e22cccef0c13b0f3f8cfb5c0331a12f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fa0943196abee8485c43570571a39cdc600682 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/fusion_patterns.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a727bdf5e1a18ac4cdb55426b169b1f164032f53 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb20f30f7f6888ba0b4f195b360b2555c473785 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447e1e375062380294252f7cce295e5a1d40c262 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fab36e95614aa3280b5058c2c661a4f1da17e45 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..205a8be80d4a20d65c604e517886f3be1d72283d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_patterns.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec01fc3cf94e19b9e978e7add9a05eb45b6a105 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2b5b5affb16d17f19eaf332e7f389b453d267bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b8611d4a769a9c1e93682180becc5117020d55 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py @@ -0,0 +1,39 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx._equalize import ( + _convert_equalization_ref, + _InputEqualizationObserver, + _WeightEqualizationObserver, + calculate_equalization_scale, + clear_weight_quant_obs_node, + convert_eq_obs, + CUSTOM_MODULE_SUPP_LIST, + custom_module_supports_equalization, + default_equalization_qconfig, + EqualizationQConfig, + fused_module_supports_equalization, + get_equalization_qconfig_dict, + get_layer_sqnr_dict, + get_op_node_and_weight_eq_obs, + input_equalization_observer, + is_equalization_observer, + maybe_get_next_equalization_scale, + maybe_get_next_input_eq_obs, + maybe_get_weight_eq_obs_node, + nn_module_supports_equalization, + node_supports_equalization, + remove_node, + reshape_scale, + scale_input_observer, + scale_weight_functional, + scale_weight_node, + update_obs_for_equalization, + weight_equalization_observer, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/convert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..30a661da41e5e2bb417a0e0aa6c7088a1b8ea7e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/convert.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.convert import convert diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fuse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..22ad750e9f8784376cecee4f5d10cfcd1488a7ac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fuse.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.fuse import fuse diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..982d919655f36320c87e066fa04e8ab10e70a719 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..74b63903d7400c037ca15ac7b9cf200d70d07ab9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.graph_module import ( + _is_observed_module, + _is_observed_standalone_module, + FusedGraphModule, + GraphModule, + ObservedGraphModule, + ObservedStandaloneGraphModule, + QuantizedGraphModule, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8585a21ad445dd20338d24267d8a0f05f96d0f92 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.match_utils import ( + _find_matches, + _is_match, + _MatchResult, + MatchAllNode, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa601d1eb619c14a37f95177b9850942ab361974 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py @@ -0,0 +1,36 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.pattern_utils import ( + _register_fusion_pattern, + _register_quant_pattern, + get_default_fusion_patterns, + get_default_output_activation_post_process_map, + get_default_quant_patterns, + QuantizeHandler, +) + + +# QuantizeHandler.__module__ = _NAMESPACE +_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +_register_quant_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_quant_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_output_activation_post_process_map.__module__ = ( + "torch.ao.quantization.fx.pattern_utils" +) + +# __all__ = [ +# "QuantizeHandler", +# "_register_fusion_pattern", +# "get_default_fusion_patterns", +# "_register_quant_pattern", +# "get_default_quant_patterns", +# "get_default_output_activation_post_process_map", +# ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/prepare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..a6007ef242af5d33566065a0b9d570399deccf94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/prepare.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.prepare import prepare diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..89f8d4406e9126525d6c1518c6743a5c84c7b760 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py @@ -0,0 +1,49 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.quantize_handler import ( + BatchNormQuantizeHandler, + BinaryOpQuantizeHandler, + CatQuantizeHandler, + ConvReluQuantizeHandler, + CopyNodeQuantizeHandler, + CustomModuleQuantizeHandler, + DefaultNodeQuantizeHandler, + EmbeddingQuantizeHandler, + FixedQParamsOpQuantizeHandler, + GeneralTensorShapeOpQuantizeHandler, + LinearReLUQuantizeHandler, + QuantizeHandler, + RNNDynamicQuantizeHandler, + StandaloneModuleQuantizeHandler, +) + + +QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +ConvReluQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +LinearReLUQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BatchNormQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +EmbeddingQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +RNNDynamicQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +DefaultNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +FixedQParamsOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +CopyNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CustomModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +GeneralTensorShapeOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +StandaloneModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py new file mode 100644 index 0000000000000000000000000000000000000000..0820ea057078ea89da763b1c5864089b8682a9f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e45c82b8fb6f2379a5805442666f5551c2680683 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/quantization/fx/utils.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.utils import ( + all_node_args_have_no_tensors, + assert_and_get_unique_device, + create_getattr_from_value, + get_custom_module_class_keys, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + get_qconv_prepack_op, + graph_module_from_producer_nodes, + maybe_get_next_module, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..721534197b7d32041befe0d34f3ebc6d63573a15 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ed9e5444225b761f7c7374595404096c9d92d6f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc99f985ef0f669289b45db04cf75eb628200265 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd157c10f47d317631509b711b93fe7c47550ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6b1e956f52804b34b960a01562bdd522a5d1786 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f7b290c89f365d1294b9329e47b02c7e6bad377 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8156113feeb84f44ed7ca71426beb9e73602aad1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fac9df75f5996fefc46c50f7738e53a24af01ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89a4257949f982099d5411c5055f13cadfc5c905 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbc08bddfa627387290350485a46ac3a182e14f2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b04e1b8a72930243f86b26eb9bc8f7c8fca277a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61cc4a467051ac008b6a417bbc2204dd45428877 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7994063362dfd1ebe7db74dac75f9ec11752bffc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6217a4468b39b7389f3d4fbece745f4829bdf690 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf9c0ba55175b2cee0c64722bfaccbfbc9568ef1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..686e5bdc5b6fb2c7806cd6d88906e28b4bf6675f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5163f27892d28e86410d7b6d76f3f394b9f1ec2b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..197644a552485df47268315044b0ce3a7733c29e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e7455bef7323f71e09375a377de1cd7eed8e6d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4e586cde50962e4700fa4e78bd0041df5625600 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7701824cbd1131d7024a66c1be75e7f173903d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd5f0363ee983408add33d060c9e24eb9e8f6a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84fd195f5f656b778ccd259cd081517fcb51f40e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365543ca03c501d61634455551d6bbf976e568ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d2170e7f0be399bf179dea9db15f22b7fb6c9e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d57c8fce26d46fd086927205d9a9a8ab7560a500 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_pytree_test_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f721b62f25ff7863f36b646b51f6901b32101ec2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aea54b08fa856cb460a622cbc3675106c3edfc55 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..621c8c855f98cb2230a68d28e63c0d5ff6043bf4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8dfc45347f523e813f258eb012f9ec08e027232 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a600dc03da51a6f900b2c83a328ef7b4adbf529b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7984625932b1ce17f6b285ec6cb661c134a21794 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..695a3ffa114c9cfb73a6b08f1e3cf3dbce6f4170 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c70af30e8da6cde1e0095338dd6999d59b0d9e13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53eaf5018e48a228d8bfd93a0abaffb4630840ed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d710cd5a11ae0e2cfc3fb5dfdf3841a8ab19c7e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a59d6f493a7a17f5728cb51cd4beeb33102d5c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..110b8dd99e9b3bc5fe62135b2666ec34df527c1a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b276874ced41ec533d9297a09a05ddd99e30d9e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8caa76ea4349faade5ce47b8a6c66d50fe7975f5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78d35929125af4cb21a14e188bc848d8ebb9dfd8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd357c2a1491f8cabd5caf53051ab4224a16ab9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e28ccb3893da8e6562bdb0e7e6ebc2a548d824d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py new file mode 100644 index 0000000000000000000000000000000000000000..46abb4bb758dde5752d974f5459ccd77ac9c0f74 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py @@ -0,0 +1,633 @@ +# mypy: ignore-errors + +import torch +from functools import partial +from torch.testing import make_tensor +from torch.testing._internal.opinfo.core import ( + OpInfo, + SampleInput, +) +from torch.testing._internal.common_dtype import all_types_and +import numpy as np + +# Note: [autograd.Function db] +# +# This is a collection of autograd.Function test cases written as OpInfos +# so they can easily be consumed by OpInfo-based tests to check if a subsystem +# supports autograd.Function. +# +# Axes: +# - saves {output, input, intermediate, non-tensor} +# - {inputs, output} x {single tensor, tensors, arbitrary objects} +# - Uses {mark_dirty, mark_non_differentiable, once_differentiable} + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +class NumpyCube(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + dinput = torch.tensor(3 * input_np ** 2, device=input.device) + return torch.tensor(input_np ** 3, device=input.device), dinput + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0], output[1]) + ctx.save_for_forward(inputs[0], output[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input) + + @staticmethod + def vmap(info, in_dims, input): + result = NumpyCube.apply(input) + return result, (in_dims[0], in_dims[0]) + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +class CubeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x): + return x ** 3, 3 * x ** 2 + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(inputs[0], outputs[1]) + ctx.save_for_forward(inputs[0], outputs[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + _input, dinput = ctx.saved_tensors + result = grad_output * dinput + 6 * dinput + return result + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(1, low=0.8, high=2), args=()) + + +class NumpyCubeNotComposable(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + return torch.tensor(input_np ** 3, device=input.device), input_np + + @staticmethod + def setup_context(ctx, inputs, output): + _, input_np = output + ctx.input_np = input_np + ctx.device = inputs[0].device + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output, grad_saved): + result_np = 3 * (ctx.input_np ** 2) + return torch.tensor(result_np, device=ctx.device) + + +class NumpyMul(torch.autograd.Function): + @staticmethod + def forward(x, y): + return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = NumpyMul.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = NumpyMul.apply(grad_output, x) + return gx, gy + + @staticmethod + def vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = NumpyMul.apply(x, y) + result = result.movedim(-1, 0) + return result, 0 + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + +def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Broadcasting + yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),)) + +def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14}) + +class MulGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x * y + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = MulGenVmap.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = MulGenVmap.apply(grad_output, x) + return gx, gy + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + + +class NumpyExp_(torch.autograd.Function): + @staticmethod + def forward(x): + x_np = to_numpy(x) + np.exp(x_np, x_np) + return x + + @staticmethod + def setup_context(ctx, inputs, output): + x, = inputs + ctx.mark_dirty(x) + ctx.save_for_backward(output) + ctx.save_for_forward(output) + + @staticmethod + def backward(ctx, grad_output): + output, = ctx.saved_tensors + return NumpyMul.apply(grad_output, output) + + @staticmethod + def vmap(info, in_dims, x): + NumpyExp_.apply(x) + return x, in_dims[0] + + @staticmethod + def jvp(ctx, x_tangent): + # Doesn't call numpy operations because I didn't want to write NumpyMul_ + output, = ctx.saved_tensors + x_tangent.mul_(output) + return x_tangent + +class NumpySort(torch.autograd.Function): + @staticmethod + def forward(x, dim): + device = x.device + x = to_numpy(x) + ind = np.argsort(x, axis=dim) + ind_inv = np.argsort(ind, axis=dim) + return ( + torch.tensor(x, device=device), + torch.tensor(ind, device=device), + torch.tensor(ind_inv, device=device), + ) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, dim = inputs + _, ind, ind_inv = output + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + # wrap dim + dim = dim if dim >= 0 else dim + x.dim() - 1 + return NumpySort.apply(x, dim + 1), (0, 0, 0) + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + +class SortGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, dim): + ind = torch.argsort(x, dim=dim) + ind_inv = torch.argsort(ind, axis=dim) + result = torch.take_along_dim(x, ind, dim=dim) + return result, ind, ind_inv + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, dim = inputs + _, ind, ind_inv = outputs + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + + +def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(1,)) + + +def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor = make_arg(3, 5) + dim = 1 + _, ind, ind_inv = NumpySort.apply(tensor, 1) + yield SampleInput(tensor, args=(ind, ind_inv, dim)) + + +class NumpyTake(torch.autograd.Function): + @staticmethod + def forward(x, ind, ind_inv, dim): + device = x.device + x = to_numpy(x) + ind = to_numpy(ind) + return torch.tensor(np.take_along_axis(x, ind, dim), device=device) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0 + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + assert ind_tangent is None + assert ind_inv_tangent is None + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim) + +class TakeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, ind, ind_inv, dim): + return torch.take_along_dim(x, ind, dim) + + @staticmethod + def setup_context(ctx, inputs, outputs): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim) + +class Select(torch.autograd.Function): + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return Select.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return Select.apply(x_tangent, ctx.idx) + +class SelectGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def jvp(ctx, x_tangent, _): + return SelectGenVmap.apply(x_tangent, ctx.idx) + + +def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(2,)) + +class ScaleGradGenVmap(torch.autograd.Function): + generate_vmap_rule = True + scale = 3.14 + + @staticmethod + def forward(x): + return x.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, grad_output): + return grad_output * ScaleGradGenVmap.scale + + @staticmethod + def jvp(ctx, x_tangent): + return x_tangent * ScaleGradGenVmap.scale + +class ZeroGradientsGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x.clone(), y.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + # Intentionally too-large gradient + torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + @staticmethod + def jvp(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + +def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5)) + + +class ForwardHasDefaultArgs(torch.autograd.Function): + @staticmethod + def forward(x, idx=(2,)): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return ForwardHasDefaultArgs.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx) + + +autograd_function_db = [ + OpInfo( + 'NumpyCubeAutogradFunction', + op=NumpyCube.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyExpMarkDirtyAutogradFunction', + op=lambda x: NumpyExp_.apply(x.clone()), + inplace_variant=NumpyExp_.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulAutogradFunction', + op=NumpyMul.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyCubeNotComposableAutogradFunction', + op=lambda x: NumpyCubeNotComposable.apply(x)[0], + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpySortAutogradFunction', + op=NumpySort.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'NumpyTakeAutogradFunction', + op=NumpyTake.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_take, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SelectAutogradFunction', + op=Select.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'CubeGenVmapAutogradFunction', + op=CubeGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'MulGenVmapAutogradFunction', + op=MulGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SortGenVmapAutogradFunction', + op=SortGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'SelectGenVmapAutogradFunction', + op=SelectGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ScaleGradGenVmapAutogradFunction', + op=ScaleGradGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ZeroGradientsGenVmapAutogradFunction', + op=ZeroGradientsGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ForwardHasDefaultArgsAutogradFunction', + op=ForwardHasDefaultArgs.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_forward_default_args, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5d0cf2992110f4cc0120f58fb7ba39f8f87947 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py @@ -0,0 +1,387 @@ +# mypy: ignore-errors + +r"""This file is allowed to initialize CUDA context when imported.""" + +import functools +import torch +import torch.cuda +from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS +import inspect +import contextlib +import os +import unittest + + +CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized() + + +TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 +CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None +# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN +if TEST_WITH_ROCM: + TEST_CUDNN = LazyVal(lambda: TEST_CUDA) +else: + TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))) + +TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0) +ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0)) + +SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)) +SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)) +SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0)) +SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)) +SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)) +SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)) +SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)) +SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)) +SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0)) + +IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10 + and torch.cuda.get_device_capability()[1] > 0) +IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR)) +IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)) +IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)) +IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0)) + +def evaluate_gfx_arch_within(arch_list): + if not torch.cuda.is_available(): + return False + gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName + effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) + # gcnArchName can be complicated strings like gfx90a:sramecc+:xnack- + # Hence the matching should be done reversely + return any(arch in effective_arch for arch in arch_list) + +def CDNA3OrLater(): + return evaluate_gfx_arch_within(["gfx940", "gfx941", "gfx942", "gfx950"]) + +def CDNA2OrLater(): + return evaluate_gfx_arch_within(["gfx90a", "gfx942"]) + +def evaluate_platform_supports_flash_attention(): + if TEST_WITH_ROCM: + arch_list = ["gfx90a", "gfx942", "gfx1201", "gfx950"] + if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": + arch_list += ["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] + return evaluate_gfx_arch_within(arch_list) + if TEST_CUDA: + return not IS_WINDOWS and SM80OrLater + return False + +def evaluate_platform_supports_efficient_attention(): + if TEST_WITH_ROCM: + arch_list = ["gfx90a", "gfx942", "gfx1201", "gfx950"] + if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": + arch_list += ["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] + return evaluate_gfx_arch_within(arch_list) + if TEST_CUDA: + return True + return False + +def evaluate_platform_supports_cudnn_attention(): + return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000) + +def evaluate_platform_supports_green_context(): + if IS_WINDOWS: + return False + if not _get_torch_cuda_version() >= (12, 8): + return False + driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run) + if driver_version is None: + return False + return int(driver_version.split('.')[0]) >= 570 + +PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) +PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) +PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention()) +# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate +PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or + PLATFORM_SUPPORTS_CUDNN_ATTENTION or + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) + +PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM + +PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater) + +PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context()) + +def evaluate_platform_supports_fp8(): + if torch.cuda.is_available(): + if torch.version.hip: + archs = ['gfx94'] + if ROCM_VERSION >= (6, 3): + archs.extend(['gfx120']) + if ROCM_VERSION >= (6, 5): + archs.append('gfx95') + for arch in archs: + if arch in torch.cuda.get_device_properties(0).gcnArchName: + return True + else: + return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) + return False + +def evaluate_platform_supports_fp8_grouped_gemm(): + if torch.cuda.is_available(): + if torch.version.hip: + if "USE_FBGEMM_GENAI" not in torch.__config__.show(): + return False + archs = ['gfx942'] + for arch in archs: + if arch in torch.cuda.get_device_properties(0).gcnArchName: + return True + else: + return SM90OrLater and not SM100OrLater + return False + +def evaluate_platform_supports_mx_gemm(): + if torch.cuda.is_available(): + if torch.version.hip: + if ROCM_VERSION >= (7, 0): + return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName + else: + return SM100OrLater + return False + +def evaluate_platform_supports_mxfp8_grouped_gemm(): + if torch.cuda.is_available() and not torch.version.hip: + built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show() + return built_with_fbgemm_genai and IS_SM100 + return False + +PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm()) +PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) +PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm()) +PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm()) + +if TEST_NUMBA: + try: + import numba.cuda + TEST_NUMBA_CUDA = numba.cuda.is_available() + except Exception: + TEST_NUMBA_CUDA = False + TEST_NUMBA = False +else: + TEST_NUMBA_CUDA = False + +# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and +# RNG have been initialized. +__cuda_ctx_rng_initialized = False + + +# after this call, CUDA context and RNG must have been initialized on each GPU +def initialize_cuda_context_rng(): + global __cuda_ctx_rng_initialized + assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng' + if not __cuda_ctx_rng_initialized: + # initialize cuda context and rng for memory tests + for i in range(torch.cuda.device_count()): + torch.randn(1, device=f"cuda:{i}") + __cuda_ctx_rng_initialized = True + + +@contextlib.contextmanager +def tf32_off(): + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + try: + torch.backends.cuda.matmul.allow_tf32 = False + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + + +@contextlib.contextmanager +def tf32_on(self, tf32_precision=1e-5): + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + old_precision = self.precision + try: + torch.backends.cuda.matmul.allow_tf32 = True + self.precision = tf32_precision + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + self.precision = old_precision + + +@contextlib.contextmanager +def tf32_enabled(): + """ + Context manager to temporarily enable TF32 for CUDA operations. + Restores the previous TF32 state after exiting the context. + """ + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + try: + torch.backends.cuda.matmul.allow_tf32 = True + with torch.backends.cudnn.flags( + enabled=None, benchmark=None, deterministic=None, allow_tf32=True + ): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + + +# This is a wrapper that wraps a test to run this test twice, one with +# allow_tf32=True, another with allow_tf32=False. When running with +# allow_tf32=True, it will use reduced precision as specified by the +# argument. For example: +# @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) +# @tf32_on_and_off(0.005) +# def test_matmul(self, device, dtype): +# a = ...; b = ...; +# c = torch.matmul(a, b) +# self.assertEqual(c, expected) +# In the above example, when testing torch.float32 and torch.complex64 on CUDA +# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at +# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced +# precision to check values. +# +# This decorator can be used for function with or without device/dtype, such as +# @tf32_on_and_off(0.005) +# def test_my_op(self) +# @tf32_on_and_off(0.005) +# def test_my_op(self, device) +# @tf32_on_and_off(0.005) +# def test_my_op(self, device, dtype) +# @tf32_on_and_off(0.005) +# def test_my_op(self, dtype) +# if neither device nor dtype is specified, it will check if the system has ampere device +# if device is specified, it will check if device is cuda +# if dtype is specified, it will check if dtype is float32 or complex64 +# tf32 and fp32 are different only when all the three checks pass +def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True): + def with_tf32_disabled(self, function_call): + with tf32_off(): + function_call() + + def with_tf32_enabled(self, function_call): + with tf32_on(self, tf32_precision): + function_call() + + def wrapper(f): + params = inspect.signature(f).parameters + arg_names = tuple(params.keys()) + + @functools.wraps(f) + def wrapped(*args, **kwargs): + kwargs.update(zip(arg_names, args, strict=False)) + cond = torch.cuda.is_tf32_supported() and only_if + if 'device' in kwargs: + cond = cond and (torch.device(kwargs['device']).type == 'cuda') + if 'dtype' in kwargs: + cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64}) + if cond: + with_tf32_disabled(kwargs['self'], lambda: f(**kwargs)) + with_tf32_enabled(kwargs['self'], lambda: f(**kwargs)) + else: + f(**kwargs) + + return wrapped + return wrapper + +# This is a wrapper that wraps a test to run it with TF32 turned off. +# This wrapper is designed to be used when a test uses matmul or convolutions +# but the purpose of that test is not testing matmul or convolutions. +# Disabling TF32 will enforce torch.float tensors to be always computed +# at full precision. +def with_tf32_off(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + with tf32_off(): + return f(*args, **kwargs) + + return wrapped + +def _get_magma_version(): + if 'Magma' not in torch.__config__.show(): + return (0, 0) + position = torch.__config__.show().find('Magma ') + version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0] + return tuple(int(x) for x in version_str.split(".")) + +def _get_torch_cuda_version(): + if torch.version.cuda is None: + return (0, 0) + cuda_version = str(torch.version.cuda) + return tuple(int(x) for x in cuda_version.split(".")) + +def _get_torch_rocm_version(): + if not TEST_WITH_ROCM or torch.version.hip is None: + return (0, 0) + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha + return tuple(int(x) for x in rocm_version.split(".")) + +def _check_cusparse_generic_available(): + return not TEST_WITH_ROCM + +def _check_hipsparse_generic_available(): + if not TEST_WITH_ROCM: + return False + if not torch.version.hip: + return False + + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) + return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1)) + + +TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available() +TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available() + +# Shared by test_torch.py and test_multigpu.py +def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + # Create a module+optimizer that will use scaling, and a control module+optimizer + # that will not use scaling, against which the scaling-enabled module+optimizer can be compared. + mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device) + mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device) + with torch.no_grad(): + for c, s in zip(mod_control.parameters(), mod_scaling.parameters(), strict=True): + s.copy_(c) + + kwargs = {"lr": 1.0} + if optimizer_kwargs is not None: + kwargs.update(optimizer_kwargs) + opt_control = optimizer_ctor(mod_control.parameters(), **kwargs) + opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs) + + return mod_control, mod_scaling, opt_control, opt_scaling + +# Shared by test_torch.py, test_cuda.py and test_multigpu.py +def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): + data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)), + (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))] + + loss_fn = torch.nn.MSELoss().to(device) + + skip_iter = 2 + + return _create_scaling_models_optimizers( + device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, + ) + (data, loss_fn, skip_iter) + + +def xfailIfSM89(func): + return func if not IS_SM89 else unittest.expectedFailure(func) + +def xfailIfSM89PreCUDA13(func): + """xfail on SM89 only for CUDA < 13. On CUDA 13+, test should pass on all architectures.""" + if IS_SM89 and _get_torch_cuda_version() < (13, 0): + return unittest.expectedFailure(func) + return func + +def xfailIfSM100OrLater(func): + return func if not SM100OrLater else unittest.expectedFailure(func) + +def xfailIfSM120OrLater(func): + return func if not SM120OrLater else unittest.expectedFailure(func) + +def xfailIfDistributedNotSupported(func): + return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func) + +# Importing this module should NOT eagerly initialize CUDA +if not CUDA_ALREADY_INITIALIZED_ON_IMPORT: + assert not torch.cuda.is_initialized() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9acc6f0f7567627c30411ed4ddf61ba2022418bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py @@ -0,0 +1,2038 @@ +# mypy: ignore-errors + +import copy +import gc +import inspect +import os +import runpy +import sys +import threading +import unittest +from collections import namedtuple +from collections.abc import Callable, Iterable, Sequence +from enum import Enum +from functools import partial, wraps +from typing import Any, ClassVar, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch._inductor.utils import GPU_TYPES +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + _get_torch_rocm_version, + TEST_CUSPARSE_GENERIC, + TEST_HIPSPARSE_GENERIC, +) +from torch.testing._internal.common_dtype import get_all_dtypes +from torch.testing._internal.common_utils import ( + _TestParametrizer, + clear_tracked_input, + compose_parametrize_fns, + dtype_name, + get_tracked_input, + IS_FBCODE, + IS_MACOS, + is_privateuse1_backend_available, + IS_REMOTE_GPU, + IS_S390X, + IS_SANDCASTLE, + IS_WINDOWS, + NATIVE_DEVICES, + PRINT_REPRO_ON_FAILURE, + skipCUDANonDefaultStreamIf, + skipIfTorchDynamo, + TEST_HPU, + TEST_MKL, + TEST_MPS, + TEST_WITH_ASAN, + TEST_WITH_MIOPEN_SUGGEST_NHWC, + TEST_WITH_MTIA, + TEST_WITH_ROCM, + TEST_WITH_TORCHINDUCTOR, + TEST_WITH_TSAN, + TEST_WITH_UBSAN, + TEST_XPU, + TestCase, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +try: + import psutil # type: ignore[import] + + HAS_PSUTIL = True +except ModuleNotFoundError: + HAS_PSUTIL = False + psutil = None + +# Note [Writing Test Templates] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# This note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# PyTorch has its own framework for instantiating test templates. That is, for +# taking test classes that look similar to unittest or pytest +# compatible test classes and optionally doing the following: +# +# - instantiating a version of the test class for each available device type +# (often the CPU, CUDA, and META device types) +# - further instantiating a version of each test that's always specialized +# on the test class's device type, and optionally specialized further +# on datatypes or operators +# +# This functionality is similar to pytest's parametrize functionality +# (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable +# additional logic that specializes the instantiated test classes for their +# device types (see CPUTestBase and CUDATestBase below), supports a variety +# of composable decorators that allow for test filtering and setting +# tolerances, and allows tests parametrized by operators to instantiate +# only the subset of device type x dtype that operator supports. +# +# This framework was built to make it easier to write tests that run on +# multiple device types, multiple datatypes (dtypes), and for multiple +# operators. It's also useful for controlling which tests are run. For example, +# only tests that use a CUDA device can be run on platforms with CUDA. +# Let's dive in with an example to get an idea for how it works: +# +# -------------------------------------------------------- +# A template class (looks like a regular unittest TestCase) +# class TestClassFoo(TestCase): +# +# # A template test that can be specialized with a device +# # NOTE: this test case is not runnable by unittest or pytest because it +# # accepts an extra positional argument, "device", that they do not understand +# def test_bar(self, device): +# pass +# +# # Function that instantiates a template class and its tests +# instantiate_device_type_tests(TestCommon, globals()) +# -------------------------------------------------------- +# +# In the above code example we see a template class and a single test template +# that can be instantiated with a device. The function +# instantiate_device_type_tests(), called at file scope, instantiates +# new test classes, one per available device type, and new tests in those +# classes from these templates. It actually does this by removing +# the class TestClassFoo and replacing it with classes like TestClassFooCPU +# and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase +# and CUDATestBase respectively. Additional device types, like XLA, +# (see https://github.com/pytorch/xla) can further extend the set of +# instantiated test classes to create classes like TestClassFooXLA. +# +# The test template, test_bar(), is also instantiated. In this case the template +# is only specialized on a device, so (depending on the available device +# types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda() +# in TestClassFooCUDA. We can think of the instantiated test classes as +# looking like this: +# +# -------------------------------------------------------- +# # An instantiated test class for the CPU device type +# class TestClassFooCPU(CPUTestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cpu(self): +# test_bar(self, 'cpu') +# +# # An instantiated test class for the CUDA device type +# class TestClassFooCUDA(CUDATestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cuda(self): +# test_bar(self, 'cuda:0') +# -------------------------------------------------------- +# +# These instantiated test classes ARE discoverable and runnable by both +# unittest and pytest. One thing that may be confusing, however, is that +# attempting to run "test_bar" will not work, despite it appearing in the +# original template code. This is because "test_bar" is no longer discoverable +# after instantiate_device_type_tests() runs, as the above snippet shows. +# Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both +# can be run with the option "-k test_bar". +# +# Removing the template class and adding the instantiated classes requires +# passing "globals()" to instantiate_device_type_tests(), because it +# edits the file's Python objects. +# +# As mentioned, tests can be additionally parametrized on dtypes or +# operators. Datatype parametrization uses the @dtypes decorator and +# require a test template like this: +# +# -------------------------------------------------------- +# # A template test that can be specialized with a device and a datatype (dtype) +# @dtypes(torch.float32, torch.int64) +# def test_car(self, device, dtype) +# pass +# -------------------------------------------------------- +# +# If the CPU and CUDA device types are available this test would be +# instantiated as 4 tests that cover the cross-product of the two dtypes +# and two device types: +# +# - test_car_cpu_float32 +# - test_car_cpu_int64 +# - test_car_cuda_float32 +# - test_car_cuda_int64 +# +# The dtype is passed as a torch.dtype object. +# +# Tests parametrized on operators (actually on OpInfos, more on that in a +# moment...) use the @ops decorator and require a test template like this: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and OpInfo +# @ops(op_db) +# def test_car(self, device, dtype, op) +# pass +# -------------------------------------------------------- +# +# See the documentation for the @ops decorator below for additional details +# on how to use it and see the note [OpInfos] in +# common_methods_invocations.py for more details on OpInfos. +# +# A test parametrized over the entire "op_db", which contains hundreds of +# OpInfos, will likely have hundreds or thousands of instantiations. The +# test will be instantiated on the cross-product of device types, operators, +# and the dtypes the operator supports on that device type. The instantiated +# tests will have names like: +# +# - test_car_add_cpu_float32 +# - test_car_sub_cuda_int64 +# +# The first instantiated test calls the original test_car() with the OpInfo +# for torch.add as its "op" argument, the string 'cpu' for its "device" argument, +# and the dtype torch.float32 for is "dtype" argument. The second instantiated +# test calls the test_car() with the OpInfo for torch.sub, a CUDA device string +# like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype +# torch.int64 for its "dtype argument." +# +# In addition to parametrizing over device, dtype, and ops via OpInfos, the +# @parametrize decorator is supported for arbitrary parametrizations: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and value for x +# @parametrize("x", range(5)) +# def test_car(self, device, dtype, x) +# pass +# -------------------------------------------------------- +# +# See the documentation for @parametrize in common_utils.py for additional details +# on this. Note that the instantiate_device_type_tests() function will handle +# such parametrizations; there is no need to additionally call +# instantiate_parametrized_tests(). +# +# Clever test filtering can be very useful when working with parametrized +# tests. "-k test_car" would run every instantiated variant of the test_car() +# test template, and "-k test_car_add" runs every variant instantiated with +# torch.add. +# +# It is important to use the passed device and dtype as appropriate. Use +# helper functions like make_tensor() that require explicitly specifying +# the device and dtype so they're not forgotten. +# +# Test templates can use a variety of composable decorators to specify +# additional options and requirements, some are listed here: +# +# - @deviceCountAtLeast() +# Passes a list of strings representing all available devices of +# the test class's device type as the test template's "device" argument. +# If there are fewer devices than the value passed to the decorator +# the test is skipped. +# - @dtypes() +# In addition to accepting multiple dtypes, the @dtypes decorator +# can accept a sequence of tuple pairs of dtypes. The test template +# will be called with each tuple for its "dtype" argument. +# - @onlyNativeDeviceTypes +# Skips the test if the device is not a native device type (currently CPU, CUDA, Meta) +# - @onlyCPU +# Skips the test if the device is not a CPU device +# - @onlyCUDA +# Skips the test if the device is not a CUDA device +# - @onlyMPS +# Skips the test if the device is not a MPS device +# - @skipCPUIfNoLapack +# Skips the test if the device is a CPU device and LAPACK is not installed +# - @skipCPUIfNoMkl +# Skips the test if the device is a CPU device and MKL is not installed +# - @skipCUDAIfNoMagma +# Skips the test if the device is a CUDA device and MAGMA is not installed +# - @skipCUDAIfRocm +# Skips the test if the device is a CUDA device and ROCm is being used + + +# Note [Adding a Device Type] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To add a device type: +# +# (1) Create a new "TestBase" extending DeviceTypeTestBase. +# See CPUTestBase and CUDATestBase below. +# (2) Define the "device_type" attribute of the base to be the +# appropriate string. +# (3) Add logic to this file that appends your base class to +# device_type_test_bases when your device type is available. +# (4) (Optional) Write setUpClass/tearDownClass class methods that +# instantiate dependencies (see MAGMA in CUDATestBase). +# (5) (Optional) Override the "instantiate_test" method for total +# control over how your class creates tests. +# +# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF +# they are run. This makes it useful for initializing devices and dependencies. + + +def _dtype_test_suffix(dtypes): + """Returns the test suffix for a dtype, sequence of dtypes, or None.""" + if isinstance(dtypes, (list, tuple)): + if len(dtypes) == 0: + return "" + return "_" + "_".join(dtype_name(d) for d in dtypes) + elif dtypes: + return f"_{dtype_name(dtypes)}" + else: + return "" + + +def _update_param_kwargs(param_kwargs, name, value): + """Adds a kwarg with the specified name and value to the param_kwargs dict.""" + # Make name plural (e.g. devices / dtypes) if the value is composite. + plural_name = f"{name}s" + + # Clear out old entries of the arg if any. + if name in param_kwargs: + del param_kwargs[name] + if plural_name in param_kwargs: + del param_kwargs[plural_name] + + if isinstance(value, (list, tuple)): + param_kwargs[plural_name] = value + elif value is not None: + param_kwargs[name] = value + + # Leave param_kwargs as-is when value is None. + + +class DeviceTypeTestBase(TestCase): + device_type: str = "generic_device_type" + + # Flag to disable test suite early due to unrecoverable error such as CUDA error. + _stop_test_suite = False + + # Precision is a thread-local setting since it may be overridden per test + _tls = threading.local() + _tls.precision = TestCase._precision + _tls.rel_tol = TestCase._rel_tol + + @property + def precision(self): + return self._tls.precision + + @precision.setter + def precision(self, prec): + self._tls.precision = prec + + @property + def rel_tol(self): + return self._tls.rel_tol + + @rel_tol.setter + def rel_tol(self, prec): + self._tls.rel_tol = prec + + # Returns a string representing the device that single device tests should use. + # Note: single device tests use this device exclusively. + @classmethod + def get_primary_device(cls): + return cls.device_type + + @classmethod + def _init_and_get_primary_device(cls): + try: + return cls.get_primary_device() + except Exception: + # For CUDATestBase, XPUTestBase, XLATestBase, and possibly others, the primary device won't be available + # until setUpClass() sets it. Call that manually here if needed. + if hasattr(cls, "setUpClass"): + cls.setUpClass() + return cls.get_primary_device() + + # Returns a list of strings representing all available devices of this + # device type. The primary device must be the first string in the list + # and the list must contain no duplicates. + # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic + # mechanism of acquiring all available devices. + @classmethod + def get_all_devices(cls): + return [cls.get_primary_device()] + + # Returns the dtypes the test has requested. + # Prefers device-specific dtype specifications over generic ones. + @classmethod + def _get_dtypes(cls, test): + if not hasattr(test, "dtypes"): + return None + + default_dtypes = test.dtypes.get("all") + msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it" + assert default_dtypes is not None, msg + + return test.dtypes.get(cls.device_type, default_dtypes) + + def _get_precision_override(self, test, dtype): + if not hasattr(test, "precision_overrides"): + return self.precision + return test.precision_overrides.get(dtype, self.precision) + + def _get_tolerance_override(self, test, dtype): + if not hasattr(test, "tolerance_overrides"): + return self.precision, self.rel_tol + return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol)) + + def _apply_precision_override_for_test(self, test, param_kwargs): + dtype = param_kwargs.get("dtype") + dtype = param_kwargs.get("dtypes", dtype) + if dtype: + self.precision = self._get_precision_override(test, dtype) + self.precision, self.rel_tol = self._get_tolerance_override(test, dtype) + + # Creates device-specific tests. + @classmethod + def instantiate_test(cls, name, test, *, generic_cls=None): + def instantiate_test_helper( + cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: [] + ): + # Add the device param kwarg if the test needs device or devices. + param_kwargs = {} if param_kwargs is None else param_kwargs + test_sig_params = inspect.signature(test).parameters + if "device" in test_sig_params or "devices" in test_sig_params: + device_arg: str = cls._init_and_get_primary_device() + if hasattr(test, "num_required_devices"): + device_arg = cls.get_all_devices() + _update_param_kwargs(param_kwargs, "device", device_arg) + + # Apply decorators based on param kwargs. + for decorator in decorator_fn(param_kwargs): + test = decorator(test) + + # Constructs the test + @wraps(test) + def instantiated_test(self, param_kwargs=param_kwargs): + # Sets precision and runs test + # Note: precision is reset after the test is run + guard_precision = self.precision + guard_rel_tol = self.rel_tol + try: + self._apply_precision_override_for_test(test, param_kwargs) + result = test(self, **param_kwargs) + except RuntimeError as rte: + # check if rte should stop entire test suite. + self._stop_test_suite = self._should_stop_test_suite() + # Check if test has been decorated with `@expectedFailure` + # Using `__unittest_expecting_failure__` attribute, see + # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164 + # In that case, make it fail with "unexpected success" by suppressing exception + if ( + getattr(test, "__unittest_expecting_failure__", False) + and self._stop_test_suite + ): + import sys + + print( + "Suppressing fatal exception to trigger unexpected success", + file=sys.stderr, + ) + return + # raise the runtime error as is for the test suite to record. + raise rte + finally: + self.precision = guard_precision + self.rel_tol = guard_rel_tol + + return result + + assert not hasattr(cls, name), f"Redefinition of test {name}" + setattr(cls, name, instantiated_test) + + def default_parametrize_fn(test, generic_cls, device_cls): + # By default, no parametrization is needed. + yield (test, "", {}, lambda _: []) + + # Parametrization decorators set the parametrize_fn attribute on the test. + parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn) + + # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it. + dtypes = cls._get_dtypes(test) + if dtypes is not None: + + def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes): + for dtype in dtypes: + param_kwargs: dict[str, Any] = {} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # Note that an empty test suffix is set here so that the dtype can be appended + # later after the device. + yield (test, "", param_kwargs, lambda _: []) + + parametrize_fn = compose_parametrize_fns( + dtype_parametrize_fn, parametrize_fn + ) + + # Instantiate the parametrized tests. + for ( + test, # noqa: B020 + test_suffix, + param_kwargs, + decorator_fn, + ) in parametrize_fn(test, generic_cls, cls): + test_suffix = "" if test_suffix == "" else "_" + test_suffix + cls_device_type = ( + cls.device_type + if cls.device_type != "privateuse1" + else torch._C._get_privateuse1_backend_name() + ) + device_suffix = "_" + cls_device_type + + # Note: device and dtype suffix placement + # Special handling here to place dtype(s) after device according to test name convention. + dtype_kwarg = None + if "dtype" in param_kwargs or "dtypes" in param_kwargs: + dtype_kwarg = ( + param_kwargs["dtypes"] + if "dtypes" in param_kwargs + else param_kwargs["dtype"] + ) + test_name = ( + f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}" + ) + + instantiate_test_helper( + cls=cls, + name=test_name, + test=test, + param_kwargs=param_kwargs, + decorator_fn=decorator_fn, + ) + + def run(self, result=None): + super().run(result=result) + # Early terminate test if _stop_test_suite is set. + if self._stop_test_suite: + result.stop() + + +class CPUTestBase(DeviceTypeTestBase): + device_type = "cpu" + + # No critical error should stop CPU test suite + def _should_stop_test_suite(self): + return False + + +class CUDATestBase(DeviceTypeTestBase): + device_type = "cuda" + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + primary_device: ClassVar[str] + cudnn_version: ClassVar[Any] + no_magma: ClassVar[bool] + no_cudnn: ClassVar[bool] + + def has_cudnn(self): + return not self.no_cudnn + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = torch.cuda.device_count() + + prim_device = cls.get_primary_device() + cuda_str = "cuda:{0}" + non_primary_devices = [ + cuda_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + # has_magma shows up after cuda is initialized + t = torch.ones(1).cuda() + cls.no_magma = not torch.cuda.has_magma + + # Determines if cuDNN is available and its version + cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t) + cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() + + # Acquires the current device as the primary (test) device + cls.primary_device = f"cuda:{torch.cuda.current_device()}" + + +# See Note [Lazy Tensor tests in device agnostic testing] +lazy_ts_backend_init = False + + +class LazyTestBase(DeviceTypeTestBase): + device_type = "lazy" + + def _should_stop_test_suite(self): + return False + + @classmethod + def setUpClass(cls): + import torch._lazy + import torch._lazy.metrics + import torch._lazy.ts_backend + + global lazy_ts_backend_init + if not lazy_ts_backend_init: + # Need to connect the TS backend to lazy key before running tests + torch._lazy.ts_backend.init() + lazy_ts_backend_init = True + + +class MPSTestBase(DeviceTypeTestBase): + device_type = "mps" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + prim_device = cls.get_primary_device() + return [prim_device] + + @classmethod + def setUpClass(cls): + cls.primary_device = "mps:0" + + def _should_stop_test_suite(self): + return False + + +class XPUTestBase(DeviceTypeTestBase): + device_type = "xpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = torch.xpu.device_count() + + prim_device = cls.get_primary_device() + xpu_str = "xpu:{0}" + non_primary_devices = [ + xpu_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + cls.primary_device = f"xpu:{torch.xpu.current_device()}" + + def _should_stop_test_suite(self): + return False + + +class HPUTestBase(DeviceTypeTestBase): + device_type = "hpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def setUpClass(cls): + cls.primary_device = "hpu:0" + + +class PrivateUse1TestBase(DeviceTypeTestBase): + primary_device: ClassVar[str] + device_mod = None + device_type = "privateuse1" + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = cls.device_mod.device_count() + prim_device = cls.get_primary_device() + device_str = f"{cls.device_type}:{{0}}" + non_primary_devices = [ + device_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + cls.device_type = torch._C._get_privateuse1_backend_name() + cls.device_mod = getattr(torch, cls.device_type, None) + assert ( + cls.device_mod is not None + ), f"""torch has no module of `{cls.device_type}`, you should register + a module by `torch._register_device_module`.""" + cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}" + + +# Adds available device-type-specific test base classes +def get_device_type_test_bases(): + # set type to List[Any] due to mypy list-of-union issue: + # https://github.com/python/mypy/issues/3351 + test_bases: list[Any] = [] + + if IS_SANDCASTLE or IS_FBCODE: + if IS_REMOTE_GPU: + # Skip if sanitizer is enabled or we're on MTIA machines + if ( + not TEST_WITH_ASAN + and not TEST_WITH_TSAN + and not TEST_WITH_UBSAN + and not TEST_WITH_MTIA + ): + test_bases.append(CUDATestBase) + else: + test_bases.append(CPUTestBase) + else: + test_bases.append(CPUTestBase) + if torch.cuda.is_available(): + test_bases.append(CUDATestBase) + + if is_privateuse1_backend_available(): + test_bases.append(PrivateUse1TestBase) + # Disable MPS testing in generic device testing temporarily while we're + # ramping up support. + # elif torch.backends.mps.is_available(): + # test_bases.append(MPSTestBase) + + return test_bases + + +device_type_test_bases = get_device_type_test_bases() + + +def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None): + # device type cannot appear in both except_for and only_for + intersect = set(except_for if except_for else []) & set( + only_for if only_for else [] + ) + assert not intersect, ( + f"device ({intersect}) appeared in both except_for and only_for" + ) + + # Replace your privateuse1 backend name with 'privateuse1' + if is_privateuse1_backend_available(): + privateuse1_backend_name = torch._C._get_privateuse1_backend_name() + + def func_replace(x: str): + return x.replace(privateuse1_backend_name, "privateuse1") + + except_for = ( + ([func_replace(x) for x in except_for] if except_for is not None else None) + if not isinstance(except_for, str) + else func_replace(except_for) + ) + only_for = ( + ([func_replace(x) for x in only_for] if only_for is not None else None) + if not isinstance(only_for, str) + else func_replace(only_for) + ) + + if except_for: + device_type_test_bases = filter( + lambda x: x.device_type not in except_for, device_type_test_bases + ) + if only_for: + device_type_test_bases = filter( + lambda x: x.device_type in only_for, device_type_test_bases + ) + + return list(device_type_test_bases) + + +# Note [How to extend DeviceTypeTestBase to add new test device] +# The following logic optionally allows downstream projects like pytorch/xla to +# add more test devices. +# Instructions: +# - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project. +# - Inside the file, one should inherit from `DeviceTypeTestBase` class and define +# a new DeviceTypeTest class (e.g. `XLATestBase`) with proper implementation of +# `instantiate_test` method. +# - DO NOT import common_device_type inside the file. +# `runpy.run_path` with `globals()` already properly setup the context so that +# `DeviceTypeTestBase` is already available. +# - Set a top-level variable `TEST_CLASS` equal to your new class. +# E.g. TEST_CLASS = XLATensorBase +# - To run tests with new device type, set `TORCH_TEST_DEVICE` env variable to path +# to this file. Multiple paths can be separated by `:`. +# See pytorch/xla/test/pytorch_test_base.py for a more detailed example. +_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None) +if _TORCH_TEST_DEVICES: + for path in _TORCH_TEST_DEVICES.split(":"): + # runpy (a stdlib module) lacks annotations + mod = runpy.run_path(path, init_globals=globals()) # type: ignore[func-returns-value] + device_type_test_bases.append(mod["TEST_CLASS"]) + + +PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1" + +PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR" +PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR" +PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM" + + +def get_desired_device_type_test_bases( + except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False +): + # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy` + test_bases = device_type_test_bases.copy() + if allow_mps and TEST_MPS and MPSTestBase not in test_bases: + test_bases.append(MPSTestBase) + if allow_xpu and TEST_XPU and XPUTestBase not in test_bases: + test_bases.append(XPUTestBase) + if TEST_HPU and HPUTestBase not in test_bases: + test_bases.append(HPUTestBase) + # Filter out the device types based on user inputs + desired_device_type_test_bases = filter_desired_device_types( + test_bases, except_for, only_for + ) + if include_lazy: + # Note [Lazy Tensor tests in device agnostic testing] + # Right now, test_view_ops.py runs with LazyTensor. + # We don't want to opt every device-agnostic test into using the lazy device, + # because many of them will fail. + # So instead, the only way to opt a specific device-agnostic test file into + # lazy tensor testing is with include_lazy=True + if IS_FBCODE: + print( + "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds", + file=sys.stderr, + ) + else: + desired_device_type_test_bases.append(LazyTestBase) + + def split_if_not_empty(x: str): + return x.split(",") if x else [] + + # run some cuda testcases on other devices if available + # Usage: + # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1 + env_custom_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "") + ) + if env_custom_only_for: + desired_device_type_test_bases += filter( + lambda x: x.device_type in env_custom_only_for, test_bases + ) + desired_device_type_test_bases = list(set(desired_device_type_test_bases)) + + # Filter out the device types based on environment variables if available + # Usage: + # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu + # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla + env_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "") + ) + env_except_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "") + ) + + return filter_desired_device_types( + desired_device_type_test_bases, env_except_for, env_only_for + ) + + +# Adds 'instantiated' device-specific test cases to the given scope. +# The tests in these test cases are derived from the generic tests in +# generic_test_class. This function should be used instead of +# instantiate_parametrized_tests() if the test class contains +# device-specific tests (NB: this supports additional @parametrize usage). +# +# See note "Writing Test Templates" +# TODO: remove "allow_xpu" option after Interl GPU support all test case instantiate by this function. +def instantiate_device_type_tests( + generic_test_class, + scope, + except_for=None, + only_for=None, + include_lazy=False, + allow_mps=False, + allow_xpu=False, +): + # Removes the generic test class from its enclosing scope so its tests + # are not discoverable. + del scope[generic_test_class.__name__] + + generic_members = set(generic_test_class.__dict__.keys()) + generic_tests = [x for x in generic_members if x.startswith("test")] + + # Creates device-specific test cases + for base in get_desired_device_type_test_bases( + except_for, only_for, include_lazy, allow_mps, allow_xpu + ): + class_name = generic_test_class.__name__ + base.device_type.upper() + + # type set to Any and suppressed due to unsupported runtime class: + # https://github.com/python/mypy/wiki/Unsupported-Python-Features + device_type_test_class: Any = type(class_name, (base, generic_test_class), {}) + + # Arrange for setUpClass and tearDownClass methods defined both in the test template + # class and in the generic base to be called. This allows device-parameterized test + # classes to support setup and teardown. + # NB: This should be done before instantiate_test() is called as that invokes setup. + @classmethod + def _setUpClass(cls): + # This should always be called, whether or not the test class invokes + # super().setUpClass(), to set the primary device. + base.setUpClass() + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.setUpClass.__func__(cls) + + @classmethod + def _tearDownClass(cls): + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.tearDownClass.__func__(cls) + base.tearDownClass() + + device_type_test_class.setUpClass = _setUpClass + device_type_test_class.tearDownClass = _tearDownClass + + for name in generic_members: + if name in generic_tests: # Instantiates test member + test = getattr(generic_test_class, name) + # XLA-compat shim (XLA's instantiate_test takes doesn't take generic_cls) + sig = inspect.signature(device_type_test_class.instantiate_test) + if len(sig.parameters) == 3: + # Instantiates the device-specific tests + device_type_test_class.instantiate_test( + name, copy.deepcopy(test), generic_cls=generic_test_class + ) + else: + device_type_test_class.instantiate_test(name, copy.deepcopy(test)) + # Ports non-test member. Setup / teardown have already been handled above + elif name not in device_type_test_class.__dict__: + nontest = getattr(generic_test_class, name) + setattr(device_type_test_class, name, nontest) + + # Mimics defining the instantiated class in the caller's file + # by setting its module to the given class's and adding + # the module to the given scope. + # This lets the instantiated class be discovered by unittest. + device_type_test_class.__module__ = generic_test_class.__module__ + scope[class_name] = device_type_test_class + + # Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're + # not discoverable. This mutates the original class (TestFoo), which was removed from + # scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda) + # have already been created and the generic forms are no longer needed. + for name in generic_tests: + delattr(generic_test_class, name) + + +# Category of dtypes to run an OpInfo-based test for +# Example use: @ops(dtype=OpDTypes.supported) +# +# There are 7 categories: +# - supported: Every dtype supported by the operator. Use for exhaustive +# testing of all dtypes. +# - unsupported: Run tests on dtypes not supported by the operator. e.g. for +# testing the operator raises an error and doesn't crash. +# - supported_backward: Every dtype supported by the operator's backward pass. +# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass. +# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the +# operator supports in both forward and backward. +# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test +# when this is selected. +# - any_common_cpu_cuda_one: Pick a dtype that supports both CPU and CUDA. +class OpDTypes(Enum): + supported = 0 # Test all supported dtypes (default) + unsupported = 1 # Test only unsupported dtypes + supported_backward = 2 # Test all supported backward dtypes + unsupported_backward = 3 # Test only unsupported backward dtypes + any_one = 4 # Test precisely one supported dtype + none = 5 # Instantiate no dtype variants (no dtype kwarg needed) + any_common_cpu_cuda_one = ( + 6 # Test precisely one supported dtype that is common to both cuda and cpu + ) + + +# Arbitrary order +ANY_DTYPE_ORDER = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + torch.float16, + torch.bfloat16, + torch.long, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.float8_e4m3fn, + torch.float8_e5m2, +) + + +def _serialize_sample(sample_input): + # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. + if getattr(sample_input, "summary", None) is not None: + return sample_input.summary() + return str(sample_input) + + +# Decorator that defines the OpInfos a test template should be instantiated for. +# +# Example usage: +# +# @ops(unary_ufuncs) +# def test_numerics(self, device, dtype, op): +# +# +# This will instantiate variants of test_numerics for each given OpInfo, +# on each device the OpInfo's operator supports, and for every dtype supported by +# that operator. There are a few caveats to the dtype rule, explained below. +# +# The @ops decorator can accept two +# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified +# then the test variants are instantiated for those dtypes, regardless of +# what the operator supports. If given "allowed_dtypes" then test variants +# are instantiated only for the intersection of allowed_dtypes and the dtypes +# they would otherwise be instantiated with. That is, allowed_dtypes composes +# with the options listed above and below. +# +# The "dtypes" argument can also accept additional values (see OpDTypes above): +# OpDTypes.supported - the test is instantiated for all dtypes the operator +# supports +# OpDTypes.unsupported - the test is instantiated for all dtypes the operator +# doesn't support +# OpDTypes.supported_backward - the test is instantiated for all dtypes the +# operator's gradient formula supports +# OpDTypes.unsupported_backward - the test is instantiated for all dtypes the +# operator's gradient formula doesn't support +# OpDTypes.any_one - the test is instantiated for one dtype the +# operator supports. The dtype supports forward and backward if possible. +# OpDTypes.none - the test is instantiated without any dtype. The test signature +# should not include a dtype kwarg in this case. +# OpDTypes.any_common_cpu_cuda_one - the test is instantiated for a dtype +# that supports both CPU and CUDA. +# +# These options allow tests to have considerable control over the dtypes +# they're instantiated for. + + +class ops(_TestParametrizer): + def __init__( + self, + op_list, + *, + dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, + allowed_dtypes: Optional[Sequence[torch.dtype]] = None, + skip_if_dynamo=True, + ): + self.op_list = list(op_list) + self.opinfo_dtypes = dtypes + self.allowed_dtypes = ( + set(allowed_dtypes) if allowed_dtypes is not None else None + ) + self.skip_if_dynamo = skip_if_dynamo + + def _parametrize_test(self, test, generic_cls, device_cls): + """Parameterizes the given test function across each op and its associated dtypes.""" + if device_cls is None: + raise RuntimeError( + "The @ops decorator is only intended to be used in a device-specific " + "context; use it with instantiate_device_type_tests() instead of " + "instantiate_parametrized_tests()" + ) + + op = check_exhausted_iterator = object() + for op in self.op_list: + # Determine the set of dtypes to use. + dtypes: Union[set[torch.dtype], set[None]] + if isinstance(self.opinfo_dtypes, Sequence): + dtypes = set(self.opinfo_dtypes) + elif self.opinfo_dtypes == OpDTypes.unsupported_backward: + dtypes = set(get_all_dtypes()).difference( + op.supported_backward_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported_backward: + dtypes = op.supported_backward_dtypes(device_cls.device_type) + elif self.opinfo_dtypes == OpDTypes.unsupported: + dtypes = set(get_all_dtypes()).difference( + op.supported_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported: + dtypes = set(op.supported_dtypes(device_cls.device_type)) + elif self.opinfo_dtypes == OpDTypes.any_one: + # Tries to pick a dtype that supports both forward or backward + supported = op.supported_dtypes(device_cls.device_type) + supported_backward = op.supported_backward_dtypes( + device_cls.device_type + ) + supported_both = supported.intersection(supported_backward) + dtype_set = supported_both if len(supported_both) > 0 else supported + for dtype in ANY_DTYPE_ORDER: + if dtype in dtype_set: + dtypes = {dtype} + break + else: + dtypes = {} + elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one: + # Tries to pick a dtype that supports both CPU and CUDA + supported = set(op.dtypes).intersection(op.dtypesIfCUDA) + if supported: + dtypes = { + next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported) + } + else: + dtypes = {} + + elif self.opinfo_dtypes == OpDTypes.none: + dtypes = {None} + else: + raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}") + + if self.allowed_dtypes is not None: + dtypes = dtypes.intersection(self.allowed_dtypes) + + # Construct the test name; device / dtype parts are handled outside. + # See [Note: device and dtype suffix placement] + test_name = op.formatted_name + + # Filter sample skips / xfails to only those that apply to the OpInfo. + # These are defined on the test function via decorators. + sample_skips_and_xfails = getattr(test, "sample_skips_and_xfails", None) + if sample_skips_and_xfails is not None: + sample_skips_and_xfails = [ + rule + for rule in sample_skips_and_xfails + if rule.op_match_fn(device_cls.device_type, op) + ] + + for dtype in dtypes: + # Construct parameter kwargs to pass to the test. + param_kwargs = {"op": op} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # NOTE: test_wrapper exists because we don't want to apply + # op-specific decorators to the original test. + # Test-specific decorators are applied to the original test, + # however. + try: + + @wraps(test) + def test_wrapper(*args, **kwargs): + try: + return test(*args, **kwargs) + except unittest.SkipTest as e: + raise e + except Exception as e: + tracked_input = get_tracked_input() + if PRINT_REPRO_ON_FAILURE and tracked_input is not None: + e_tracked = Exception( # noqa: TRY002 + f"{str(e)}\n\nCaused by {tracked_input.type_desc} " + f"at index {tracked_input.index}: " + f"{_serialize_sample(tracked_input.val)}" + ) + e_tracked._tracked_input = tracked_input # type: ignore[attr] + raise e_tracked from e + raise e + finally: + clear_tracked_input() + + if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR: + test_wrapper = skipIfTorchDynamo( + "Policy: we don't run OpInfo tests w/ Dynamo" + )(test_wrapper) + + # Initialize info for the last input seen. This is useful for tracking + # down which inputs caused a test failure. Note that TrackedInputIter is + # responsible for managing this. + test.tracked_input = None + + decorator_fn = partial( + op.get_decorators, + generic_cls.__name__, + test.__name__, + device_cls.device_type, + dtype, + ) + + if sample_skips_and_xfails is not None: + test_wrapper.sample_skips_and_xfails = sample_skips_and_xfails + + yield (test_wrapper, test_name, param_kwargs, decorator_fn) + except Exception as ex: + # Provides an error message for debugging before rethrowing the exception + print(f"Failed to instantiate {test_name} for op {op.name}!") + raise ex + if op is check_exhausted_iterator: + raise ValueError( + "An empty op_list was passed to @ops. " + "Note that this may result from reuse of a generator." + ) + + +# Decorator that skips a test if the given condition is true. +# Notes: +# (1) Skip conditions stack. +# (2) Skip conditions can be bools or strings. If a string the +# test base must have defined the corresponding attribute to be False +# for the test to run. If you want to use a string argument you should +# probably define a new decorator instead (see below). +# (3) Prefer the existing decorators to defining the 'device_type' kwarg. +class skipIf: + def __init__(self, dep, reason, device_type=None): + self.dep = dep + self.reason = reason + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def dep_fn(slf, *args, **kwargs): + if ( + self.device_type is None + or self.device_type == slf.device_type + or ( + isinstance(self.device_type, Iterable) + and slf.device_type in self.device_type + ) + ): + if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( + isinstance(self.dep, bool) and self.dep + ): + raise unittest.SkipTest(self.reason) + + return fn(slf, *args, **kwargs) + + return dep_fn + + +# Skips a test on CPU if the condition is true. +class skipCPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cpu") + + +# Skips a test on CUDA if the condition is true. +class skipCUDAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cuda") + + +# Skips a test on XPU if the condition is true. +class skipXPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xpu") + + +# Skips a test on XPU or CUDA if the condition is true. +class skipGPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type=GPU_TYPES) + + +# Skips a test on Lazy if the condition is true. +class skipLazyIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="lazy") + + +# Skips a test on Meta if the condition is true. +class skipMetaIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="meta") + + +# Skips a test on MPS if the condition is true. +class skipMPSIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="mps") + + +class skipHPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="hpu") + + +# Skips a test on XLA if the condition is true. +class skipXLAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xla") + + +class skipPRIVATEUSE1If(skipIf): + def __init__(self, dep, reason): + device_type = torch._C._get_privateuse1_backend_name() + super().__init__(dep, reason, device_type=device_type) + + +def _has_sufficient_memory(device, size): + device_ = torch.device(device) + device_type = device_.type + if device_type in ["cuda", "xpu"]: + acc = torch.accelerator.current_accelerator() + # Case 1: no accelerator found + if not acc: + return False + # Case 2: accelerator found but not matching device type + if acc.type != device_type: + return True + # Case 3: accelerator found and matching device type but not available + if not torch.accelerator.is_available(): + return False + # Case 4: accelerator found and matching device type and available + gc.collect() + torch.accelerator.empty_cache() + + if device_.index is None: + device_ = torch.device(device_type, 0) + + if device_type == "cuda": + return ( + torch.cuda.memory.mem_get_info(device_)[0] + * torch.cuda.memory.get_per_process_memory_fraction(device_) + ) >= size + + if device_type == "xpu": + return torch.xpu.memory.mem_get_info(device_)[0] >= size + + if device_type == "xla": + raise unittest.SkipTest("TODO: Memory availability checks for XLA?") + + if device_type != "cpu": + raise unittest.SkipTest("Unknown device type") + + # CPU + if not HAS_PSUTIL: + raise unittest.SkipTest("Need psutil to determine if memory is sufficient") + + # The sanitizers have significant memory overheads + if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN: + effective_size = size * 10 + else: + effective_size = size + + # don't try using all RAM on s390x, leave some for service processes + if IS_S390X: + effective_size = effective_size * 2 + + if psutil.virtual_memory().available < effective_size: + gc.collect() + return psutil.virtual_memory().available >= effective_size + + +def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR): + """Skip test if the device has insufficient memory to run the test + + size may be a number of bytes, a string of the form "N GB", or a callable + + If the test is a device generic test, available memory on the primary device will be checked. + It can also be overridden by the optional `device=` argument. + In other tests, the `device=` argument needs to be specified. + """ + if isinstance(size, str): + assert size.endswith(("GB", "gb")), "only bytes or GB supported" + size = 1024**3 * int(size[:-2]) + + def inner(fn): + @wraps(fn) + def dep_fn(self, *args, **kwargs): + size_bytes: int = size(self, *args, **kwargs) if callable(size) else size + _device = device + if _device is None: + if hasattr(self, "get_primary_device"): + _device = self.get_primary_device() + else: + _device = self.device + + # If this is running with GPU cpp_wrapper, the autotuning step will generate + # an additional array of the same size as the input. + if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu": + size_bytes *= 2 + if not _has_sufficient_memory(_device, size_bytes): + raise unittest.SkipTest(f"Insufficient {_device} memory") + + return fn(self, *args, **kwargs) + + return dep_fn + + return inner + + +class expectedFailure: + def __init__(self, device_type, dtype=None): + self.device_type = device_type + self.dtype = dtype + + def __call__(self, fn): + @wraps(fn) + def efail_fn(slf, *args, **kwargs): + if ( + not hasattr(slf, "device_type") + and hasattr(slf, "device") + and isinstance(slf.device, str) + ): + target_device_type = slf.device + else: + target_device_type = slf.device_type + + target_dtype = kwargs.get("dtype", getattr(slf, "dtype", None)) + device_matches = ( + self.device_type is None or self.device_type == target_device_type + ) + dtype_matches = self.dtype is None or self.dtype == target_dtype + + if device_matches and dtype_matches: + try: + fn(slf, *args, **kwargs) + except Exception: + return + else: + slf.fail("expected test to fail, but it passed") + + return fn(slf, *args, **kwargs) + + return efail_fn + + +class onlyOn: + def __init__(self, device_type: Union[str, list]): + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def only_fn(slf, *args, **kwargs): + if slf.device_type not in self.device_type: + reason = f"Only runs on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(slf, *args, **kwargs) + + return only_fn + + +# Decorator that provides all available devices of the device type to the test +# as a list of strings instead of providing a single device string. +# Skips the test if the number of available devices of the variant's device +# type is less than the 'num_required_devices' arg. +class deviceCountAtLeast: + def __init__(self, num_required_devices): + self.num_required_devices = num_required_devices + + def __call__(self, fn): + assert not hasattr(fn, "num_required_devices"), ( + f"deviceCountAtLeast redefinition for {fn.__name__}" + ) + fn.num_required_devices = self.num_required_devices + + @wraps(fn) + def multi_fn(slf, devices, *args, **kwargs): + if len(devices) < self.num_required_devices: + reason = f"fewer than {self.num_required_devices} devices detected" + raise unittest.SkipTest(reason) + + return fn(slf, devices, *args, **kwargs) + + return multi_fn + + +# Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1) +def onlyNativeDeviceTypes(fn: Callable[_P, _T]) -> Callable[_P, _T]: + @wraps(fn) + def only_fn(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if self.device_type not in NATIVE_DEVICES: + reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +# Only runs the test on the native device types and devices specified in the devices list +def onlyNativeDeviceTypesAnd(devices=None): + def decorator(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if ( + self.device_type not in NATIVE_DEVICES + and self.device_type not in devices + ): + reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + return decorator + + +# Specifies per-dtype precision overrides. +# Ex. +# +# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's precision will be set to the +# corresponding override, if it exists. +# self.precision can be accessed directly, and it also controls the behavior of +# functions like self.assertEqual(). +# +# Note that self.precision is a scalar value, so if you require multiple +# precisions (or are working with multiple dtypes) they should be specified +# explicitly and computed using self.precision (e.g. +# self.precision *2, max(1, self.precision)). +class precisionOverride: + def __init__(self, d): + assert isinstance(d, dict), ( + "precisionOverride not given a dtype : precision dict!" + ) + for dtype in d: + assert isinstance(dtype, torch.dtype), ( + f"precisionOverride given unknown dtype {dtype}" + ) + + self.d = d + + def __call__(self, fn): + fn.precision_overrides = self.d + return fn + + +# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over +# precisionOverride. +# Ex. +# +# @toleranceOverride({torch.float : tol(atol=1e-2, rtol=1e-3}, +# torch.double : tol{atol=1e-4, rtol = 0}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's tolerance will be set to the +# corresponding override, if it exists. +# self.rtol and self.precision can be accessed directly, and they also control +# the behavior of functions like self.assertEqual(). +# +# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and +# atol = 1e-4 and rtol = 0 for torch.double. +tol = namedtuple("tol", ["atol", "rtol"]) + + +class toleranceOverride: + def __init__(self, d): + assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" + for dtype, prec in d.items(): + assert isinstance(dtype, torch.dtype), ( + f"toleranceOverride given unknown dtype {dtype}" + ) + assert isinstance(prec, tol), ( + "toleranceOverride not given a dtype : tol dict!" + ) + + self.d = d + + def __call__(self, fn): + fn.tolerance_overrides = self.d + return fn + + +# Decorator that instantiates a variant of the test for each given dtype. +# Notes: +# (1) Tests that accept the dtype argument MUST use this decorator. +# (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU +# or dtypesIfCUDA. +# (3) Can accept an iterable of dtypes or an iterable of tuples +# of dtypes. +# Examples: +# @dtypes(torch.float32, torch.float64) +# @dtypes((torch.long, torch.float32), (torch.int, torch.float64)) +class dtypes: + def __init__(self, *args, device_type="all"): + if len(args) > 0 and isinstance(args[0], (list, tuple)): + for arg in args: + assert isinstance(arg, (list, tuple)), ( + "When one dtype variant is a tuple or list, " + "all dtype variants must be. " + f"Received non-list non-tuple dtype {str(arg)}" + ) + assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( + f"Unknown dtype in {str(arg)}" + ) + else: + assert all(isinstance(arg, torch.dtype) for arg in args), ( + f"Unknown dtype in {str(args)}" + ) + + self.args = args + self.device_type = device_type + + def __call__(self, fn): + d = getattr(fn, "dtypes", {}) + assert self.device_type not in d, f"dtypes redefinition for {self.device_type}" + d[self.device_type] = self.args + fn.dtypes = d + return fn + + +# Overrides specified dtypes on the CPU. +class dtypesIfCPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cpu") + + +# Overrides specified dtypes on CUDA. +class dtypesIfCUDA(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cuda") + + +# Overrides specified dtypes on Intel GPU. +class dtypesIfXPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="xpu") + + +class dtypesIfMPS(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="mps") + + +class dtypesIfHPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="hpu") + + +class dtypesIfPRIVATEUSE1(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name()) + + +def onlyCPU(fn): + return onlyOn("cpu")(fn) + + +def onlyCUDA(fn): + return onlyOn("cuda")(fn) + + +def onlyMPS(fn): + return onlyOn("mps")(fn) + + +def onlyXPU(fn): + return onlyOn("xpu")(fn) + + +def onlyHPU(fn): + return onlyOn("hpu")(fn) + + +def onlyPRIVATEUSE1(fn): + device_type = torch._C._get_privateuse1_backend_name() + device_mod = getattr(torch, device_type, None) + if device_mod is None: + reason = f"Skip as torch has no module of {device_type}" + return unittest.skip(reason)(fn) + return onlyOn(device_type)(fn) + + +def onlyCUDAAndPRIVATEUSE1(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()): + reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +def disablecuDNN(fn): + @wraps(fn) + def disable_cudnn(self, *args, **kwargs): + if self.device_type == "cuda" and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_cudnn + + +def disableMkldnn(fn): + @wraps(fn) + def disable_mkldnn(self, *args, **kwargs): + if torch.backends.mkldnn.is_available(): + with torch.backends.mkldnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_mkldnn + + +def expectedFailureCPU(fn): + return expectedFailure("cpu")(fn) + + +def expectedFailureCUDA(fn): + return expectedFailure("cuda")(fn) + + +def expectedFailureXPU(fn): + return expectedFailure("xpu")(fn) + + +def expectedFailureMeta(fn): + return skipIfTorchDynamo()(expectedFailure("meta")(fn)) + + +def expectedFailureXLA(fn): + return expectedFailure("xla")(fn) + + +def expectedFailureHPU(fn): + return expectedFailure("hpu")(fn) + + +def expectedFailureMPS(fn): + return expectedFailure("mps")(fn) + + +def expectedFailureMPSComplex(fn): + return expectedFailure("mps", torch.complex64)(fn) + + +def expectedFailureMPSPre15(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 15.0: + return expectedFailure("mps")(fn) + return fn + + +def expectedFailureMPSPre14(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 14.0: + return expectedFailure("mps")(fn) + return fn + + +# Skips a test on CPU if LAPACK is not available. +def skipCPUIfNoLapack(fn): + return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) + + +# Skips a test on CPU if FFT is not available. +def skipCPUIfNoFFT(fn): + return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")( + fn + ) + + +# Skips a test on CPU if MKL is not available. +def skipCPUIfNoMkl(fn): + return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn) + + +# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows). +def skipCPUIfNoMklSparse(fn): + return skipCPUIf( + IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support" + )(fn) + + +# Skips a test on CPU if mkldnn is not available. +def skipCPUIfNoMkldnn(fn): + return skipCPUIf( + not torch.backends.mkldnn.is_available(), + "PyTorch is built without mkldnn support", + )(fn) + + +# Skips a test on CUDA if MAGMA is not available. +def skipCUDAIfNoMagma(fn): + return skipCUDAIf("no_magma", "no MAGMA library detected")( + skipCUDANonDefaultStreamIf(True)(fn) + ) + + +def has_cusolver(): + return not TEST_WITH_ROCM + + +def has_hipsolver(): + rocm_version = _get_torch_rocm_version() + # hipSOLVER is disabled on ROCM < 5.3 + return rocm_version >= (5, 3) + + +# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available +def skipCUDAIfNoCusolver(fn): + return skipCUDAIf( + not has_cusolver() and not has_hipsolver(), "cuSOLVER not available" + )(fn) + + +# Skips a test if both cuSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoCusolver(fn): + if has_cusolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoLinalgsolver(fn): + if has_cusolver() or has_hipsolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipCUDAIfRocm: {msg}" + return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn) + + if func: + return dec_fn(func) + return dec_fn + + +# Skips a test on CUDA when not using ROCm. +def skipCUDAIfNotRocm(fn): + return skipCUDAIf( + not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack" + )(fn) + + +# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested. +def skipCUDAIfRocmVersionLessThan(version=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if not TEST_WITH_ROCM: + reason = "ROCm not available" + raise unittest.SkipTest(reason) + rocm_version_tuple = _get_torch_rocm_version() + if ( + rocm_version_tuple is None + or version is None + or rocm_version_tuple < tuple(version) + ): + reason = ( + f"ROCm {rocm_version_tuple} is available but {version} required" + ) + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfNotMiopenSuggestNHWC(fn): + return skipCUDAIf( + not TEST_WITH_MIOPEN_SUGGEST_NHWC, + "test doesn't currently work without MIOpen NHWC activation", + )(fn) + + +# Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s. +def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version in (versions or []): + reason = f"test skipped for CUDA version {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test for CUDA versions less than specified, given in the form of [major, minor]. +def skipCUDAIfVersionLessThan(versions: Optional[tuple[int, int]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version < versions: + reason = f"test skipped for CUDA versions < {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested. +def skipCUDAIfCudnnVersionLessThan(version=0): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if self.no_cudnn: + reason = "cuDNN not available" + raise unittest.SkipTest(reason) + if self.cudnn_version is None or self.cudnn_version < version: + reason = f"cuDNN version {self.cudnn_version} is available but {version} required" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuSparse generic API is not available +def skipCUDAIfNoCusparseGeneric(fn): + return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")( + fn + ) + + +def skipCUDAIfNoHipsparseGeneric(fn): + return skipCUDAIf( + not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available" + )(fn) + + +def skipCUDAIfNoSparseGeneric(fn): + return skipCUDAIf( + not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC), + "Sparse Generic API not available", + )(fn) + + +def skipCUDAIfNoCudnn(fn): + return skipCUDAIfCudnnVersionLessThan(0)(fn) + + +def skipCUDAIfMiopen(fn): + return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn) + + +def skipCUDAIfNoMiopen(fn): + return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")( + skipCUDAIfNoCudnn(fn) + ) + + +def skipLazy(fn): + return skipLazyIf(True, "test doesn't work with lazy tensors")(fn) + + +def skipMeta(fn): + return skipMetaIf(True, "test doesn't work with meta tensors")(fn) + + +def skipXLA(fn): + return skipXLAIf(True, "Marked as skipped for XLA")(fn) + + +def skipMPS(fn): + return skipMPSIf(True, "test doesn't work on MPS backend")(fn) + + +def skipHPU(fn): + return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + + +def skipXPU(fn): + return skipXPUIf(True, "test doesn't work on XPU backend")(fn) + + +def skipPRIVATEUSE1(fn): + return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) + + +# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now. +# This should probably enumerate all available device type test base classes. +def get_all_device_types() -> list[str]: + return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] + + +# skip since currently flex attention requires at least `avx2` support on CPU. +IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = ( + not torch.xpu.is_available() + and not torch.cuda.is_available() + and not IS_MACOS + and torch.cpu._is_avx2_supported() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" +) +IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED = ( + torch.xpu.is_available() and torch.utils._triton.has_triton() +) +flex_attention_supported_platform = unittest.skipUnless( + IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED + or IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED + or ( + torch.cuda.is_available() + and torch.utils._triton.has_triton() + and torch.cuda.get_device_capability() >= (8, 0) + ), + "Requires CUDA and Triton, Intel GPU and triton, or CPU with avx2 and later", +) +if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6e851d7e28b0466f9b49862f1df78781c2a461 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py @@ -0,0 +1,323 @@ +# mypy: ignore-errors + +# Torch +import torch +import torch.cuda +import torch.jit +import torch.jit._logging +import torch.jit.frontend +import torch.jit.quantized + +# Testing utils +from torch.testing._internal.common_dtype import floating_and_complex_types_and +from torch.testing._internal.common_utils import TestCase, \ + freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors +from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 + +# Standard library +from itertools import chain +from typing import Union +from torch._C import TensorType + +import io + +def check_output_types(self, func, ref_outputs, args, kwargs): + graph = getattr(func, 'last_graph', None) + types = [o.type() for o in graph.outputs()] + self.assertTrue(len(types) == 1) + t = types[0] + torch._C._jit_assert_is_instance(ref_outputs, t) + +# Test names in this set are only checked for a single derivative +nn_functional_single_grad = frozenset('test_nn_' + name for name in [ + 'pdist', + 'multilabel_margin_loss', + 'max_unpool3d', + 'multi_margin_loss', + 'binary_cross_entropy', + 'binary_cross_entropy_size_average', + 'ctc_loss', + 'grid_sample', +]) + +def check_against_reference(self, func, reference_func, output_func, args, kwargs=None, + allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False): + """Verifies a function performs identically to some reference implementation. + + Commonly, this is used to verify that a JIT implementation + (output_func) matches the behavior of the eager implementation + (reference_func). + """ + kwargs = kwargs if kwargs else {} + + def allSum(vs): + if isinstance(vs, torch.Tensor): + vs = (vs,) + return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum() + for i, v in enumerate(vs) + if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16)) + + def clone_tensor(t, preserve_requires_grad): + require_grad = preserve_requires_grad and t.requires_grad + return t.detach().clone().requires_grad_(require_grad) + + def clone_inputs(preserve_requires_grad: bool): + inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + inputs.append(clone_tensor(arg, preserve_requires_grad)) + elif is_iterable_of_tensors(arg): + inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg]) + else: + inputs.append(arg) + + return inputs + + # Returns tensors in args that requires_grad, including tensors in TensorList args + def get_recording_tensors(args): + recording_tensors: list[torch.Tensor] = [] + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.requires_grad: + recording_tensors.append(arg) + elif is_iterable_of_tensors(arg): + recording_tensors.extend(filter(lambda t: t.requires_grad, arg)) + + return recording_tensors + + # test no gradients case + nograd_inputs = clone_inputs(preserve_requires_grad=False) + outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs) + with enable_profiling_mode_for_profiling_tests(): + outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs) + self.assertEqual(outputs, outputs_test) + + if check_types: + check_output_types(self, func, outputs_test, nograd_inputs, kwargs) + + if no_grad: + # skip grad tests + return + + with enable_profiling_mode_for_profiling_tests(): + # test single grad case + recording_inputs = clone_inputs(preserve_requires_grad=True) + recording_tensors = get_recording_tensors(recording_inputs) + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) + grads = torch.autograd.grad(allSum(outputs), recording_tensors, + allow_unused=allow_unused) + outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) + grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, + allow_unused=allow_unused) + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + # test the grad grad case + if self._testMethodName in nn_functional_single_grad or no_gradgrad: + return + + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) + l1 = allSum(outputs) + grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, + allow_unused=allow_unused) + + l2 = (allSum(grads) * l1) + grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) + recording_inputs = clone_inputs(preserve_requires_grad=True) + recording_tensors = get_recording_tensors(recording_inputs) + outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) + l1_test = allSum(outputs_test) + grads_test = torch.autograd.grad( + l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) + + l2_test = (allSum(grads_test) * l1_test) + grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused) + + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + for g2, g2_test in zip(grads2, grads2_test, strict=True): + if g2 is None and g2_test is None: + continue + self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4) + +class JitCommonTestCase(TestCase): + def createFunctionFromGraph(self, trace): + graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() + return torch._C._create_function_from_graph("forward", graph) + + def assertExportImport(self, trace, inputs): + m = self.createFunctionFromGraph(trace) + self.assertExportImportModule(m, inputs) + + def assertExportImportModule(self, m, inputs): + m_import = self.getExportImportCopy(m) + a = self.runAndSaveRNG(m, inputs) + b = self.runAndSaveRNG(m_import, inputs) + self.assertEqual(a, b, "Results of original model and " + "exported/imported version of model differed") + + def runAndSaveRNG(self, func, inputs, kwargs=None): + kwargs = kwargs if kwargs else {} + with freeze_rng_state(): + results = func(*inputs, **kwargs) + return results + + def getExportImportCopy(self, m, also_test_file=True, map_location=None): + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + imported = torch.jit.load(buffer, map_location=map_location) + + if not also_test_file: + return imported + + with TemporaryFileName() as fname: + torch.jit.save(imported, fname) + return torch.jit.load(fname, map_location=map_location) + + def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph, + fusion_nodes_not_found, non_fusible_nodes_being_fused, + fusion_nodes_found, nodes_in_diff_graph): + err_msg = "\nFailure in testing nodes' autodifferentiation. " + if should_autodiff_node: + err_msg += "One or more nodes were expected to be autodiffed, " \ + "but were not found in specified fusible/nonfusible " \ + "DifferentiableGraph groups. \nSpecifically:" + # The node is intended to appear in a differentiable graph but doesn't + diff_nodes_missing = [] + # The node is intended to appear in a differentiable graph + # outside of a fusion group but instead is in a fusion group + diff_nodes_in_fusion = [] + # The node is intended to appear in a fusion group but doesn't + fusion_nodes_missing = [] + # The node is intended to appear in a fusion group but instead + # is just in an outer differentiable graph + fusion_nodes_in_diff = [] + for node in nodes_not_in_diff_graph: + if node in non_fusible_nodes_being_fused: + diff_nodes_in_fusion.append(node) + else: + diff_nodes_missing.append(node) + for node in fusion_nodes_not_found: + if node in nodes_in_diff_graph: + fusion_nodes_in_diff.append(node) + else: + fusion_nodes_missing.append(node) + if len(diff_nodes_missing) > 0: + err_msg += f"\n {diff_nodes_missing} were not in one of the " \ + "DifferentiableGraphs when they were expected to be. " \ + "Did you intend for these nodes to be autodiffed? " \ + "If not, remove them from the list of nonfusible nodes." + if len(diff_nodes_in_fusion) > 0: + err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \ + "when they were expected to be just in a DifferentiableGraph. If it was " \ + "intended for these nodes to be in FusionGroups, reclassify these nodes as " \ + "fusible nodes. If these nodes were not intended to be fused, your " \ + "autodifferentiation logic might be wrong." + if len(fusion_nodes_missing) > 0: + err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \ + "of the DifferentiableGraphs when they were expected to be. " \ + "They were also not found in an outer DifferentiableGraph. Did you " \ + "intend for these nodes to be autodifferentiated? If not, you should " \ + "remove these nodes from the test's fusible nodes. Otherwise your " \ + "autodifferentiation logic might be wrong." + if len(fusion_nodes_in_diff) > 0: + err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \ + "of the DifferentiableGraphs when they were expected to be, " \ + "instead they were found just in an outer DifferentiableGraph. " \ + "Did you intend for these nodes to be fused? If not, you should " \ + "move these nodes into the test's nonfusible nodes. Otherwise your " \ + "autodifferentiation logic might be wrong." + else: + err_msg += "One or more nodes were not expected to be autodiffed " \ + "but were found in a DifferentiableGraph or in a FusionGroup " \ + "of a DifferentiableGraph. Did you intend for these nodes to be " \ + "autodiffed? If so, change this test to expect autodifferentiation. " \ + "\nSpecifically:" + if len(fusion_nodes_found) > 0: + err_msg += f"\n {fusion_nodes_found} were not expected to be in " \ + "one of the DifferentiableGraphs, but appeared in a FusionGroup " \ + "of a DifferentiableGraph. " + if len(nodes_in_diff_graph) > 0: + err_msg += f"\n {nodes_in_diff_graph} were not expected to " \ + "be in one of the DifferentiableGraphs but were." + return err_msg + + def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): + diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') + diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] + + # Note: currently no tests have fusible_nodes + fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) + fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] + + # For any non-fusible node, it must show up in one of the DifferentiableGraphs. + nodes_in_diff_graph = [] + nodes_not_in_diff_graph = [] + non_fusible_nodes_being_fused = [] + for node in nonfusible_nodes: + if any(g.findNode(node) is not None for g in diff_subgraphs): + nodes_in_diff_graph.append(node) + else: + nodes_not_in_diff_graph.append(node) + if any(g.findNode(node) is not None for g in fusion_subgraphs): + non_fusible_nodes_being_fused.append(node) + found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes) + + # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs. + fusion_nodes_found = [] + fusion_nodes_not_found = [] + for node in fusible_nodes: + if any(g.findNode(node) is not None for g in fusion_subgraphs): + fusion_nodes_found.append(node) + else: + fusion_nodes_not_found.append(node) + found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes) + + if should_autodiff_node is not None: + err_msg = self.autoDiffErrorMessage(should_autodiff_node, + nodes_not_in_diff_graph, + fusion_nodes_not_found, + non_fusible_nodes_being_fused, + fusion_nodes_found, + nodes_in_diff_graph) + self.assertEqual(should_autodiff_node, + found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg) + + def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]], + traced_graph, assert_propagation, constant_prop=True): + # repropagte input shapes provided by tracing, + prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() + for enable_test_mode in [True, False]: + # here we are testing allowing/disallowing substituting in complete shapes as constants, + # disallowing constants helps stress test partial eval and substitution pipeline + torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode) + torch._C._jit_erase_non_input_shape_information(traced_graph) + if constant_prop: + torch._C._jit_pass_constant_propagation(traced_graph) + torch._C._jit_pass_propagate_shapes_on_graph(traced_graph) + # Add sizes to default tensor type to avoid checking something out of scope + # and difficulties with tracer leaving in other parts of tensor type + output = next(traced_graph.outputs()).type() + + def test_type(type, actual_size): + sizes = type.symbolic_sizes() + out_type = TensorType.get().with_sizes(sizes) + actual_type = TensorType.get().with_sizes(actual_size) + + # always check actual shape is a subtype of the output + self.assertTrue(actual_type.isSubtypeOf(out_type)) + + # and then if assertion flag is provided, check shape analysis + # is successful + if assert_propagation: + self.assertEqual(out_type.sizes(), actual_size) + + if output.isSubtypeOf(torch._C.TensorType.get()): + test_type(output, out_sizes) + else: + tuple_elements = output.elements() + for i in range(len(tuple_elements)): + test_type(tuple_elements[i], out_sizes[i]) + + torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py new file mode 100644 index 0000000000000000000000000000000000000000..dac77fe9aa731ade2f96a87abb16af21699e2a6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py @@ -0,0 +1,25236 @@ +# mypy: ignore-errors + +from functools import wraps, partial +from itertools import product, chain, islice +import itertools +import functools +import copy +import operator +import random +import unittest +import math +import enum + +import torch +import numpy as np +import numpy.typing as npt +from torch import inf, nan + +from typing import Any, Union +from collections.abc import Sequence +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, + floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, + empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, +) +from torch.testing._internal.common_device_type import \ + (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, + skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride, + skipCPUIfNoMklSparse, + toleranceOverride, tol, skipXPU) +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, +) +from torch.testing._internal.common_quantized import ( + _bfloat16_to_float4_e2m1fn_x2, +) +from torch.testing._internal.common_utils import ( + make_fullrank_matrices_with_distinct_singular_values, + TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY, + torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, + GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, + TEST_WITH_TORCHINDUCTOR, MACOS_VERSION, +) +from torch.testing._utils import wrapper_set_seed + +import torch._refs as refs # noqa: F401 +import torch._refs.nn.functional +import torch._refs.special +import torch._refs.linalg +import torch._prims as prims # noqa: F401 +from torch.utils import _pytree as pytree + + +from torch._vendor.packaging import version + +from torch.testing._internal.opinfo.core import ( # noqa: F401 + L, + M, + S, + XS, + _NOTHING, + _getattr_qual, + DecorateInfo, + SampleInput, + ErrorInput, + AliasInfo, + NumericsFilter, + OpInfo, + _generate_reduction_inputs, + _generate_reduction_kwargs, + sample_inputs_reduction, + ReductionOpInfo, + reference_inputs_elementwise_binary, + make_error_inputs_elementwise_binary, + generate_elementwise_binary_tensors, + generate_elementwise_binary_arbitrarily_strided_tensors, + generate_elementwise_binary_small_value_tensors, + generate_elementwise_binary_large_value_tensors, + generate_elementwise_binary_extremal_value_tensors, + generate_elementwise_binary_broadcasting_tensors, + generate_elementwise_binary_with_scalar_samples, + generate_elementwise_binary_with_scalar_and_type_promotion_samples, + generate_elementwise_binary_noncontiguous_tensors, + sample_inputs_elementwise_binary, + BinaryUfuncInfo, + sample_inputs_elementwise_unary, + generate_elementwise_unary_tensors, + generate_elementwise_unary_small_value_tensors, + generate_elementwise_unary_large_value_tensors, + generate_elementwise_unary_extremal_value_tensors, + reference_inputs_elementwise_unary, + UnaryUfuncInfo, + sample_inputs_spectral_ops, + SpectralFuncType, + SpectralFuncInfo, + ShapeFuncInfo, + sample_inputs_foreach, + ForeachFuncInfo, + gradcheck_wrapper_hermitian_input, + gradcheck_wrapper_ctc_loss, + gradcheck_wrapper_triangular_input, + gradcheck_wrapper_triangular_input_real_positive_diagonal, + gradcheck_wrapper_masked_operation, + gradcheck_wrapper_masked_pointwise_operation, + clone_sample, +) +from torch.testing._internal.opinfo.refs import ( # NOQA: F401 + _find_referenced_opinfo, + _inherit_constructor_args, + PythonRefInfo, + ReductionPythonRefInfo, + ElementwiseUnaryPythonRefInfo, + ElementwiseBinaryPythonRefInfo, +) +from torch.testing._internal.opinfo.utils import ( + np_unary_ufunc_integer_promotion_wrapper, + reference_reduction_numpy, + prod_numpy +) +from torch.testing._internal import opinfo +from torch.testing._internal.opinfo.definitions.linalg import ( + sample_inputs_linalg_cholesky, + sample_inputs_linalg_cholesky_inverse, + sample_inputs_cross, + sample_inputs_linalg_qr_geqrf, + sample_inputs_linalg_invertible, + sample_inputs_lu_solve, + sample_inputs_legacy_solve, + sample_inputs_svd, + sample_inputs_linalg_det_logdet_slogdet, + sample_inputs_linalg_lu, + sample_inputs_diagonal_diag_embed, + error_inputs_diagonal_diag_embed, +) +from torch.testing._internal.opinfo.definitions.special import ( + sample_inputs_i0_i1, + sample_inputs_polygamma, + reference_polygamma, +) +from torch.testing._internal.opinfo.definitions._masked import ( + sample_inputs_softmax_variant, +) +from torch.testing._internal.opinfo.definitions.sparse import ( + error_inputs_sparse_like_fns, + sample_inputs_sparse_like_fns, + error_inputs_sparse_mul, + sample_inputs_sparse_mul, + error_inputs_sparse_reduction_sum, + sample_inputs_sparse_reduction_sum +) + +if TEST_SCIPY: + from scipy import stats + import scipy.spatial + import scipy.special + + +def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +# test if a tensor is close to an integer +def close_to_int(x, eps=0.1): + if x.is_complex(): + y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x)))) + else: + y = torch.abs(torch.frac(x)) + return (y < eps) | (y > (1 - eps)) + + +def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs): + + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + yield SampleInput(make_input(3), 0) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3) + + yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2) + + +def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + args_cases = ( + # Cases with tensor indices. + (torch.tensor([1, 2, 3]),), + (torch.tensor(1),), + (torch.tensor([1, 2, 3]), 1), + (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1), + # Cases with list of indices. + ((2, 4),), + ((2, 4), 1), + ((2, 4), -1), + # Cases with integer section. + (3,), + (3, 1), + (3, -1), + ) + + for args in args_cases: + yield SampleInput(make_input((S, S, S)), args=args) + + +def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6, S), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + yield SampleInput(make_arg(S, S, 6), 2) + +def error_inputs_hsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, " + "but got a tensor with 0 dimensions!") + yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.hsplit attempted to split along dimension 1, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput( + SampleInput(make_arg((S, S, S)), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_vsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.vsplit attempted to split along dimension 0, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), + error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_dsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.dsplit attempted to split along dimension 2, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2) + + +def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = ( + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0), + ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0), + ) + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + kwargs = dict(storage_offset=storage_offset) + yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs) + +def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(): + base = make_tensor((20,), device=device, dtype=dtype) + return base[5:15].requires_grad_(requires_grad) + + # as_strided on offset, partial views + yield SampleInput(make_arg(), (2, 2), (1, 2)) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10) + +def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = [ + ((1,), (), (), 0), + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((3, 3), (2, 2), (2, 1), 0), + # Scatter to larger dimensions + ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0), + # Scatter to larger dimensions with strides inverted + ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0), + ] + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + input_src = make_arg(output_shape) + yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset) + + +def error_inputs_as_strided_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # Create a small tensor and try to scatter it out of bounds + input_t = make_arg([4, 4]) + input_src = make_arg([2, 2]) + yield ErrorInput( + SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0), + error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64" + ) + + +def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs): + inputs = ( + (0,), + (0, 1), + (0, 1, 2, 3), + ) + + rvals = [1, 2, 4] + + products = product(inputs, rvals, [False, True]) + + for input_data, r, with_replacement in products: + input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(input_t, r=r, with_replacement=with_replacement) + +def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # constructs 1-D tensors with varying number of elements + a = make_arg((0,)) + b = make_arg((0, 1)) + c = make_arg((0, 1, 2, 3)) + + # sample with only 1 tensor + yield SampleInput(a) + + # sample with 2 tensors + yield SampleInput(a, b) + + # sample with 3 tensors + yield SampleInput(a, b, c) + +def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input_shape, dict of dim and eps + cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S), {'dim': 1}), + ((S, 2), {'dim': -1}), + ((S,), {'dim': 0, 'eps': 0.5}), + ((), {'dim': 0}), + ((S, S, M), {'dim': 2}), + ((S, S), {}) + ) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs) + # Test for Broadcasting + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2}) + yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + + +def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + + cases = ( + (), + (()), + (1), + ((1,)), + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def error_inputs_item(op, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + + cases = ( + (M), + ((S,)), + (S, S), + (S, M, L), + ) + + for shape in cases: + yield ErrorInput( + SampleInput(make_arg(shape)), error_type=RuntimeError, + error_regex="elements cannot be converted to Scalar") + + +def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for training, momentum, eps + cases: tuple[tuple[int, ...], dict] = ( + ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}), + ((3, 2, 4), {'training': False, 'momentum': -1.2}), + ((3, 1), {'training': True, 'momentum': 0.0}), + ((0,), {'training': True}), + ((0,), {'training': False}), + ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}), + ((2, 1), {}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight = make_arg(channels) if channels > 0 else None + bias = make_arg(channels) if channels > 0 else None + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + weight, + bias + ), + kwargs=kwargs + ) + + # Checking for permutations of weights and biases as `None` + is_training = [True, False, False] + + for training in is_training: + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + make_arg(channels), + make_arg(channels) + ), + kwargs={'training': training} + ) + + # Test case for no optional kwargs + # running_mean and running_var are required in evaluation mode (training: False) but not in training mode + yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) + +def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), 0), + ((S, S), 0), + ((S, M, S), -1), + ] + input_dtypes = [dtype] + if dtype == torch.float and device == 'cuda': + input_dtypes += [torch.float16] + + for (shape, dim), input_dtype in product(cases, input_dtypes): + input = make_arg(shape) + output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype) + yield SampleInput(make_arg(shape), output, dim, input_dtype) + +def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + + +def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if args[0] is not None and args[1] is not None: + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + else: + yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps)) + +def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if any(args[i] is None for i in range(4)): + continue + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps)) + +def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs): + op_kwargs = op_info.sample_kwargs(device, dtype, None)[0] + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad, + op_kwargs=op_kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + for weight in [-1., 0., 0.8, 1.]: + weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(shape), args=(weight_tensor,)) + + channel_size = shape[1] if len(shape) >= 2 else 1 + yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),)) + + weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,)) + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),)) + +def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs) + yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs) + +def sample_kwargs_prelu_scalar_weight(device, dtype, input): + weight = torch.rand((), device=device, dtype=dtype) + # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case + if dtype == torch.bfloat16: + weight_cpu = weight.to(dtype=torch.float32, device="cpu") + else: + weight_cpu = weight.cpu() + np_weight = weight_cpu.numpy() + return ({'weight': weight}, {'weight': np_weight}) + +def error_inputs_prelu(op, device): + # Weight has numel != 1, but self.ndim is zero-dim tensor + inp = make_tensor((), device=device, dtype=torch.float32) + weight = make_tensor((2,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Not allow zero-dim input tensor.") + + # Weight has numel != 1, but numel does not match channel size + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((9,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Mismatch of parameter numbers and input channel size.") + + # Weight is neither a scalar nor 1-D tensor + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((2, 4), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2") + + # src and index tensors must have the same # of dimensions +def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # ord = inf is tested in inputs_norm_inf as it fails on some tests + cases = [ + ((S, S), (2,), '2'), + ((S, S), (0,), '0'), + ((S, S), (0.5,), '0_5'), + ((S, S), (1,), '1'), + ((S, S), (3,), '3'), + ((S, S), (-1,), 'neg_1'), + ((S, S), (-2,), 'neg_2'), + ((S, S), (-0.5,), 'neg_0_5'), + ((S, S), (-1.5,), 'neg_1_5'), + ] + + cases_nonzero_input = ( + ((S, S, S), (1.5,), '1_5_default'), + ((S, S, S), (1.5, 1), '1_5_dim'), + ((S, S, S), (1.5, -1), '1_5_neg_dim'), + ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'), + ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'), + ) + + cases_posdim = ( + ((S, S), (-2, 1,), 'neg_2_dim'), + ((S, S), (-1, 1,), 'neg_1_dim'), + ((S, S), (0, 1,), '0_dim'), + ((S, S), (1, 1,), '1_dim'), + ((S, S), (2, 1,), '2_dim'), + ((S, S), (3, 1,), '3_dim'), + ((S, S, S), (2, 1), '2_dim'), + ((S, S, S), (3, 1), '3_dim'), + ((S, S, S), (2, 1, True), 'keepdim_2_dim'), + ((S, S, S), (3, 1, True), 'keepdim_3_dim'), + ((), (2, 0), '2_dim_scalar'), + ((), (3, 0), '3_dim_scalar'), + ((), (2, 0, True), 'keepdim_2_dim_scalar'), + ((), (3, 0, True), 'keepdim_3_dim_scalar'), + ) + + cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim")) + for shape, args, name in cases_posdim) + + for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim): + yield SampleInput(make_arg(shape), args=args, name=name) + + for shape, args, name in cases_nonzero_input: + yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name) + + +def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (), 'default'), + ((S, S), ('fro',), 'fro_default'), + ((S, S), ('fro', [0, 1],), 'fro'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), ('nuc',), 'nuc'), + ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (-inf,), '-inf'), + ((S, S), (inf,), 'inf'), + ((S, S), (inf, 1,), 'inf_2_dim'), + ((S, S), (inf, -1,), 'inf_2_neg_dim'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((), (S,)), + ((S, 1), (S,)), + ((M, S), ()), + ((S, S), (S, S)) + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + rhs = make_arg(shape_rhs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input) + if shape_lhs == shape_rhs: + yield SampleInput(lhs, args=(lhs.clone().detach_(),)) + + +def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + num_inputs = kwargs.get('num_inputs') + sample_kwargs = kwargs.get('sample_kwargs', {}) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + args = [make_arg(shape_rhs) for _ in range(num_inputs - 1)] + broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)) + + yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input) + +def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs): + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((S, 1), S), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + for shape in shapes: + inp, *arg0 = shape + yield SampleInput(inp, args=tuple(arg0)) + +def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds alpha kwarg cases + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True}) + neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3 + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False}) + +def error_inputs_arange(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero') + yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range') + yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range') + +def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs): + int_samples = ( + # positive direction + (-1, 2, 2), + # negative direction + (2, -3, -1), + # start == end + (1, 1, 1), + (1, 1, -1), + # divides evenly + (0, -8, -4), + (1, 5, 2), + # bool + (False, True, True), + # default step + (0, 1, None), + # default start + (None, 3, None), + ) + + def to_float(start, end, step): + start = start + 0.1 if start is not None else None + end = end + 0.1 + step = float(step) if step is not None else None + return start, end, step + + float_samples = ( + # includes endpoint + (0., -8. - 1e-6, -4.), + (1., 5. + 1e-6, 2.), + (0., -8., -4.), + (1., 5., 2.), + *(to_float(start, end, step) for (start, end, step) in int_samples), + ) + + large_samples = ( + (0, 10000, None), + ) + + samples = int_samples + float_samples + if dtype not in (torch.int8, torch.uint8): + samples += large_samples + + for start, end, step in samples: + if start is None: + assert step is None + # Pass end as positional arg + yield SampleInput(end, kwargs={"dtype": dtype, "device": device}) + # (Similar to) calling torch.arange(end=3) + yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device}) + elif step is None: + yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(2) + yield SampleInput(1, args=(3, 1)) + +def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): + shapes = ( + (M,), + (S, S) + ) + + for shape in shapes: + yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((S, S), 0, 5), + ((S, S, S), -2, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + +def error_inputs_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = -1 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}", + ) + +def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.5), + ((S, S), 0, 1), + ((S, S, S), -2, 1), + ) + for shape, median, gamma in samples: + yield SampleInput(make_arg(shape), args=(median, gamma)) + + +def error_inputs_cauchy(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_scale = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_scale,)), + error_type=RuntimeError, + error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}", + ) + + +def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.5), + ((S, S), 1), + ((S, S, S), 1.5), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_exponential(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_rate = 0 + yield ErrorInput( + SampleInput(t, args=(invalid_rate,)), + error_type=RuntimeError, + error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}", + ) + + +def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.2), + ((S, S), 0.5), + ((S, S, S), 0.8), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_geometric(op, device, **kwargs): + t = torch.zeros([10], device=device) + neg_prob = -1 + yield ErrorInput( + SampleInput(t, args=(neg_prob,)), + error_type=RuntimeError, + error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}", + ) + + +def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.25), + ((S, S), 0.5, 1), + ((S, S, S), 0, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + + +def error_inputs_log_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}", + ) + + +def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), -100, 100), + ((S, S), 0, 1), + ((S, S, S), 1, 2), + ) + for shape, hi, lo in samples: + yield SampleInput(make_arg(shape), args=(hi, lo)) + +def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs): + # this is a bit messy, as we want the args to be tuples + # so if we pass size as a tuple, we have a tuple containing a tuple + sizes = ( + (M,), + (S, S), + ) + for size in sizes: + yield SampleInput(size, kwargs={'dtype': dtype, 'device': device}) + +def sample_inputs_full(op, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + sizes = ( + (M,), + (S, S), + ) + fill_values = [get_val(dtype), get_val(torch.int)] + + for size, fill_value in product(sizes, fill_values): + yield SampleInput(size, fill_value, dtype=dtype, device=device) + + +def error_inputs_uniform(op, device, **kwargs): + t = torch.zeros([10], device=device) + yield ErrorInput( + SampleInput(t, args=(3, -1)), + error_type=RuntimeError, + error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1", + ) + + +def error_inputs_linspace(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative') + yield ErrorInput( + SampleInput(0, args=(3, 1.)), + error_type=TypeError, + error_regex="received an invalid combination of arguments - got \\(int, int, float", + ) + yield ErrorInput( + SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)), + error_type=RuntimeError, + error_regex="only supports 0-dimensional start and end tensors" + ) + + +def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)] + for start, end, nstep in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device, requires_grad=False) + + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))] + for start, end, nstep, (is_start_tensor, is_end_tensor) in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + + tensor_options = {"dtype": dtype, "device": device} + if is_start_tensor: + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if is_end_tensor: + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + for start, end, nstep, base in product(starts, ends, nsteps, bases): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device) + for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + + tensor_options = {"dtype": dtype, "device": device} + + if (is_start_tensor): + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if (is_end_tensor): + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Creates additional inputs to test the rtol, atol, and equal_nan params + rtols = [0., 1e-7] + atols = [0., 1e-7] + equal_nans = [False, True] + + products = product(rtols, atols, equal_nans) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for rtol, atol, equal_nan in products: + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + + yield SampleInput(lhs, args=(rhs,), + kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan)) + + +def error_inputs_isclose(op, device, **kwargs): + make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}), + error_type=RuntimeError, + error_regex='rtol must be greater than or equal to zero') + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}), + error_type=RuntimeError, + error_regex='atol must be greater than or equal to zero') + + +def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((1, 2))) + yield SampleInput(make_arg((2,))) + yield SampleInput(make_arg(())) + + +def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + first_shape, second_shape = (S, M), (M, S) + + yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),)) + + if dtype.is_complex: + yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),)) + + # Matmul of empty matrices + yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),)) + yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),)) + + +def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) + tests_list = [ + ((2, 3), (2, 2), (2, 3), False), + ((3, 3), (3, 3), (3, 3), False), + ] + tests_with_lhs_broadcasting = [ + ((1,), (2, 2), (2, 3), True), + ((), (2, 2), (2, 3), True), + ] + test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] + + kwargs = dict(alpha=alpha_val, beta=beta_val) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape_a, shape_b, shape_c, broadcasts_input in test_cases: + yield SampleInput( + make_arg(shape_a), + make_arg(shape_b), + make_arg(shape_c), + **kwargs, + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shape = (3, 3) + yield SampleInput( + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + make_arg(shape), + **kwargs, + ) + yield SampleInput( + make_arg(shape), + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + **kwargs, + ) + # addmm of empty matrices + if dtype.is_floating_point: + yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs) + # empty matmul with broadcastable input + yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True) + +def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha = 2 + 3j if dtype.is_complex else 0.6 + beta = 1 + 2j if dtype.is_complex else 0.2 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C + for m, n, k in itertools.product([0, 5], repeat=3): + yield SampleInput( + torch.eye(m, n, device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + make_arg((k, n)), + alpha=alpha, + beta=beta, + ) + +def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + reductions = ["sum", "mean", "amax", "amin"] + for m, k, reduce in product([5, 7], [3, 11], reductions): + yield SampleInput( + torch.eye(m, m) + .to(device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + reduce, + ) + + +def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, M), make_arg(M)) + +def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(M, S, M), make_arg(M, M, S)) + +def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + yield SampleInput(make_arg((S, )), make_arg((S, ))) + if dtype.is_complex: + # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor) + # is tested in test_conj_view (which tests operations with only conjugated input tensor + # -- not conjugated arg tensors) + yield SampleInput(make_arg((S, )), make_arg_conj((S, ))) + + +def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)), + error_regex='dot : expected both vectors to have same dtype') + yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)), + error_regex='1D tensors expected') + yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)), + error_regex='inconsistent tensor size') + if device != "cpu" and not is_ref: + yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)), + error_regex='Expected all tensors to be on the same device') + + +def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = (((S,), (S, M), (M,), 1, 1, False), + ((S,), (S, M), (M,), 0.2, 0.6, False), + ) + + test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True), + ((1,), (S, M), (M,), 0.2, 0.6, True), + ((), (S, M), (M,), 1, 1, True), + ((), (S, M), (M,), 0.2, 0.6, True), + ) + + cases = test_cases + test_cases_with_broadcast + + # addmv performs: beta * M + alpha * (mat @ vec) + for size, mat, vec, beta, alpha, broadcasts_input in cases: + yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input) + +def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting + test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + + for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases: + if dtype.is_complex: + beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting) + +def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + test_cases = [(((S, S), (S, S), (S, S)), False), + (((S, S), (S, 1), (1, S)), False), + (((1,), (S, S, 1), (1, S)), True), + (((), (), ()), False), + (((S, S), (), ()), True), + (((), (S, S, 1), (1, S)), True) + ] + + for input_args, broadcasts_input in test_cases: + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input) + + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput( + *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3 + ).with_metadata(broadcasts_input=broadcasts_input) + +def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_addcmul_addcdiv( + op_info, device, dtype, requires_grad, **kwargs) + + # type promotion cases + supported_dtypes = op_info.supported_dtypes(device) + make_arg = partial(make_tensor, device=device, requires_grad=requires_grad) + + types = ( + (torch.float64, torch.complex128), + (torch.bfloat16, torch.float32), + ) + + values = ( + None, + True, False, + 3.14, 3, + 1.0, 1, + 0.0, 0, + -3.14, -3, + 3.14 + 2.71j, + ) + + for (type2, type3), value in product(types, values): + if (type2 not in supported_dtypes or + type3 not in supported_dtypes): + continue + + # RuntimeError: value cannot be converted without overflow + if (type(value) is complex and + type2 is not torch.complex128): + continue + + arg1 = make_arg([5, 5], dtype=dtype) + arg2 = make_arg([5, 5], dtype=type2) + arg3 = make_arg([1, 5], dtype=type3) + + # TypeError: addcdiv(): argument 'value' must be Number, not NoneType + if value is not None: + yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value)) + else: + yield SampleInput(arg1, args=(arg2, arg3)) + +def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs): + test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta, + alpha=alpha + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shapes = [(S, S, S), (S, M, S), (S, S, M)] + args = tuple(make_arg(s) for s in shapes) + yield SampleInput( + args[0].transpose_(-1, 1), + args[1].transpose(-1, 1).conj().requires_grad_(requires_grad), + args[2].transpose(-1, 1).conj().requires_grad_(requires_grad), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ) + +# TODO: add reduction kwargs +def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (S,), + (S, S), + ) + + for shape in shapes: + # Produce one with weight and one without. + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={}) + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), + kwargs={'weight': _make_tensor(shape, requires_grad=False)}) + +def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None + ) + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M)) + + yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True) + + if dtype.is_complex: + alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j + elif dtype.is_floating_point: + alpha, beta = 0.2, 0.6 + else: + alpha, beta = 2, 3 + + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha) + + yield SampleInput( + make_arg(), + make_arg(S), + make_arg(M), + beta=beta, + alpha=alpha, + ).with_metadata(broadcasts_input=True) + + # These samples fail gradcheck + if dtype.is_floating_point and not requires_grad: + tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput( + torch.tensor([[math.nan]], **tensor_options), + torch.tensor([0.0], **tensor_options), + torch.tensor([0.0], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + + yield SampleInput( + torch.tensor([[0.0]], **tensor_options), + torch.tensor([math.nan], **tensor_options), + torch.tensor([math.nan], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + +def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ((), (S, S, S), (S,)) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1), {}), + ((S,), make_target([], low=0, high=S), {"p": 1}), + ((S,), make_target([1], low=0, high=S), {"p": 2}), + ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}), + ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}), + ((M, S), make_target([M], low=0, high=S), {"weight": None}), + ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}), + ) + + for input_shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1)), + ((S,), make_target([], low=0, high=S)), + ((S,), make_target([1], low=0, high=S)), + ((M, S), make_target([M], low=0, high=S)), + ) + ps = (1, 2) + margins = (0, 7, -3.14) + weights = (False, True) + reductions = (None, "none", "mean", "sum") + + for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions): + input = _make_tensor(input_shape) + weight_shape = [input.size(-1)] if input.ndim > 0 else [1] + weight = make_weight(weight_shape, low=-10., high=10.) if weight else None + kwargs = {"p": p, "margin": margin, "weight": weight} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(input, args=(target,), kwargs=kwargs) + + +def error_inputs_multi_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]') + # invalid target dtype + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, error_regex='expected scalar type Long but found Float') + # invalid weight + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}), + error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]') + # invalid p + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}), + error_type=ValueError, error_regex='only p == 1 and p == 2 supported') + + +def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((), (0,), True), + ((S, S), (1,), True), + ((S, S), (1,), False), + ((S, S), (-2,), False), + ((S, S), (0, 1), False), + ) + # Test large inputs to check numerical stability + lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,) + for low in lows: + high = low * 2 if low is not None else None + for shape, dim, keepdim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=low, high=high, + requires_grad=requires_grad) + yield SampleInput(t, dim, keepdim) + +def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs) + + # https://github.com/pytorch/pytorch/issues/91843 + t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + # tests masking + # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073 + t = torch.tensor(float("inf")) + yield SampleInput(t, 0, True) + +def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + inputs = [ + ((), {}), + ((S, S), {}), + ((0, S, 0), {}), + ((S,), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), {'device': 'cpu'}), + ((S,), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), {'device': 'cuda'})) + + for shape, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, **kwargs) + +def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs) + + # shape + cases = ( + (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1) + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in cases: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + +def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + + inputs = ( + ([], make_target([], low=0, high=1), {}), + ([S], make_target([S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}), + ) + + for shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False) + + inputs = ( + # random tests including -1 target labels + ([], make_target([], low=-1, high=1)), + ([S], make_target([S], low=-1, high=S)), + ([M, S], make_target([M, S], low=-1, high=S)), + # repeated target labels and -1 (labels after the first -1 are ignored) + ([], make_target_tensor(-1)), + ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])), + ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])), + ) + reductions = (None, "none", "mean", "sum") + + for (shape, target), reduction in product(inputs, reductions): + kwargs = {} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def error_inputs_multilabel_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]') + + +def get_independent_tensor(tensor): + return tensor.clone().requires_grad_(tensor.requires_grad) + +def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + sample.kwargs.setdefault('device', device) + # With high + yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) + # With low and high + yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs) + +def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield SampleInput( + sample.input, + high, + *sample.args, + **sample.kwargs) + # With low and high + yield SampleInput( + get_independent_tensor(sample.input), + low, + high, + *sample.args, + **sample.kwargs) + +def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (), + (S,), + (S, S), + (S, S, S), + ) + + margins = (0., 1.) + reductions = ('sum', 'mean', 'none') + + for shape in shapes: + for margin, reduction in product(margins, reductions): + kwargs = {'margin': margin, 'reduction': reduction} + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=False), + _make_tensor(shape, requires_grad=False)), + kwargs=kwargs) + +def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp1 = make_input((10, )) + inp1[2] = float('nan') + inp2 = make_input((10, )) + inp2[4] = float('nan') + target = make_input((10, )) + inp2[9] = float('nan') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Inf handling + inp1 = make_input((10, )) + inp2[1] = float('inf') + inp2 = make_input((10, )) + inp2[4] = float('inf') + target = make_input((10, )) + inp2[7] = float('inf') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Broadcasting + inp1 = make_input((5, 2)) + inp2 = make_input((5, 1)) + target = make_input((1, 2)) + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + +def error_inputs_margin_ranking_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value. + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + # invalid input shapes + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)), + error_regex='margin_ranking_loss : All input tensors should') + +def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs): + # input_shape, output_shape, strides, kwargs + # lengths of output_shape and strides must be equal + inputs = [ + ((), (), (), {}), + ((S, S), (2, 0), (3, 4), {}), + ((0, S, 0), (3, 2, 2), (1, 2, 3), {}), + ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), + ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'})) + + for input_shape, output_shape, strides, kwargs in inputs: + t = make_tensor(input_shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + if is_strided: + yield SampleInput(t, output_shape, strides, **kwargs) + else: + yield SampleInput(t, output_shape, **kwargs) + +def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs): + + inputs = [ + ((), (), {'dtype': dtype, 'device': device}), + ((S,), (4,), {'dtype': dtype, 'device': device}), + ((S, S), (2, 1), {'dtype': dtype, 'device': device}), + ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}), + ] + + for shape, strides, kwargs in inputs: + yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs) + +def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + for layout in itertools.permutations(range(len(case))): + yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad) + +def error_inputs_empty_permuted(op_info, device, **kwargs): + yield ErrorInput( + SampleInput((2,), args=((0, 1),)), + error_type=RuntimeError, + error_regex="Number of dimensions in size does not match the length of the physical_layout" + ) + yield ErrorInput( + SampleInput((2,), args=((3,),)), + error_type=RuntimeError, + error_regex="Dimension out of range" + ) + yield ErrorInput( + SampleInput((2, 3), args=((0, 0),)), + error_type=RuntimeError, + error_regex="Duplicate dim not allowed" + ) + +def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): + # Not including a scalar tensor in vals because meta tests start failing due to + # lack of meta support for _local_scalar_dense + # torch.tensor(2, device=device) + vals = (-5, 0, 1) + + for item in vals: + yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): + # only ints >= 0 are allowed for both arguments, unless m is omitted + sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) + + for n, m in product(sizes, sizes): + if n is None: + continue + + # TODO: no layout + _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad} + if m is None: + yield SampleInput(n, args=(), kwargs=_kwargs) + else: + yield SampleInput(n, args=(m,), kwargs=_kwargs) + +def error_inputs_eye(op_info, device, **kwargs): + # TODO: no layout + _kwargs = {'device': device, 'dtype': torch.float32} + + yield ErrorInput( + SampleInput(-1, args=(), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -1" + ) + + yield ErrorInput( + SampleInput(-7, args=(42,), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -7" + ) + + yield ErrorInput( + SampleInput(0, args=(-3,), kwargs=_kwargs), + error_regex="m must be greater or equal to 0, got -3" + ) + + +def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs): + # The scalar we are passing to new_full must be the same dtype + # as the one of the resulting tensor + use_dtype = sample.kwargs.get('dtype', dtype) + yield SampleInput( + sample.input, *sample.args, get_val(use_dtype), **sample.kwargs) + +def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + double_dtype = torch.double if device != "mps:0" else torch.float + inputs = [ + ((), get_val(dtype), {}), + ((S, S), get_val(dtype), {}), + ((0, S, 0), get_val(dtype), {}), + ((S,), get_val(dtype), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), get_val(double_dtype), {'dtype': double_dtype}), + ((S,), get_val(dtype), {'device': 'cpu'}), + ((S,), get_val(double_dtype), {'dtype': double_dtype, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + + if torch.mps.is_available() and dtype not in [torch.float64, torch.complex128, torch.uint32, torch.uint16]: + inputs.append(((S,), get_val(dtype), {'device': 'mps'})) + + if not dtype.is_signed: + # For unsigned dtypes, negative values are converted. + inputs.append(((S,), -get_val(dtype), {})) + + for shape, fill_value, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, fill_value, **kwargs) + +def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs): + cases = [ + ([3], 3, {}), + ([10], 3, {}), + ([3, 10], 3, {}), + ([3], 3, dict(replacement=False)), + ([3], 3, dict(replacement=True)), + ([3, 4], 4, dict(replacement=True)), + ([3, 4], 4, dict(replacement=False)), + ] + + for shape, num_samples, kwargs in cases: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + yield SampleInput(t, num_samples, **kwargs) + +def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs): + def get_value_or_make_tensor(value_or_shape): + if isinstance(value_or_shape, list): + return make_tensor(value_or_shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + return value_or_shape + + for value_or_mean_shape, value_or_std_shape, kwargs in cases: + mean = get_value_or_make_tensor(value_or_mean_shape) + std = get_value_or_make_tensor(value_or_std_shape) + yield SampleInput(mean, std, **kwargs) + +def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs): + # value_or_size, value_or_size, kwargs + cases = [ + ([], [], {}), + ([3], [3], {}), + ([3, 4, 2], [3, 4, 2], {}), + ([2, 3], 1.1, {}), + ([1, 2, 3], [5, 2, 3], {}), # broadcasting + ] + + return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs) + +def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs): + yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device) + yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device) + yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad)) + +def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs): + shapes = [ + [3], + [], + [0, 3], + [2, 3, 4], + ] + + for shape in shapes: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=1, + requires_grad=requires_grad) + yield SampleInput(t) + +def error_inputs_bernoulli(op_info, device, **kwargs): + # more than one element of the written-to tensor refers to a single memory location + x = torch.rand((1,), device=device).expand((6,)) + err_msg = 'unsupported operation' + yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}), + error_regex=err_msg) + +def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((S, S, S), 0), + ((S, S, S), 1), + ((), 0), + ) + + for large_number in (True, False): + for shape, dim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + + if large_number and t.dim() > 0: + t[0] = 10000 + yield SampleInput(t, dim) + +def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): + yield SampleInput( + make_tensor((S, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + + +def error_inputs_trace(op, device): + yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix") + + +def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((S, S, S), (2, 1, 0.5)), + ((S, S, S), (2, -1, 0.5)), + ((S, S, S), (1, 2, 3)), + ((S, S, S), (float('inf'), 2, 0.5)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((1, 2, 3), (-1, -2)), + ((1, 2, 3), (-1, 2)), + ((1, 2, 3), (1, -2)), + ((1, 2, 3), (1, 2)), + ((), (0, 0)), + ((1, ), (0, 0)), + ((M, M), (0, 1)), + ((S, S, S), (2, 0)), ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def _numpy_ref_transpose(a, dim0, dim1): + if a.ndim <= 1: + return a + + return np.swapaxes(a, dim0, dim1) + +def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def sample_inputs_T(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((M, M), (M, L)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def error_inputs_T(self, device, has_ndims_error=False): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # Deprecated behavior in regular PyTorch, but throws an error in primTorch: + # https://github.com/pytorch/pytorch/issues/86968 + if has_ndims_error: + # ndims == 1 + yield ErrorInput(SampleInput(make_arg(M)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + # ndims > 2 + yield ErrorInput(SampleInput(make_arg(M, S, L)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + +def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False): + """ + This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n). + Their matrix product could be used to generate tensor of shape (*, m, n) of rank k. + """ + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batches = [(), (2,)] + size = [3, 4] + for batch, m, n in product(batches, size, size): + k = 2 + a = make_arg((*batch, m, k)) + b = make_arg((*batch, n, k)) + yield a, b + + +def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # Function that's well defined on the outputs for complex inputs + def fn(usv): + U, S, V = usv + return U @ V.mH, S + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + + # NOTE: since svd_lowrank relies on non rank-revealing SVD, + # it inherits the problem of unstable behavior with repeated + # singular values including zeros. + # Since we want to avoid (repeated) zeros as singular values, + # we can only use k for q. + # This issues could be resolved with using a rank-revealing SVD + # which does not include "zero" singular values. + yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn) + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn) + +def chunk_iter(iterable, size): + it = iter(iterable) + while True: + chunk = tuple(islice(it, size)) + if not chunk: + break + yield chunk + +def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # we reuse samples from svd_lowrank which come in group of two with + # kwarg['M'] = None and with kwarg['M'] = + samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs) + for s1, s2 in chunk_iter(samples, 2): + del s1.kwargs['M'] + del s2.kwargs['M'] + s1.kwargs['center'] = False + s2.kwargs['center'] = True + yield s1 + yield s2 + +def np_sinc_with_fp16_as_fp32(x): + # Wraps numpy's sinc function so that fp16 values are promoted to fp32 + # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated + # at 0 for fp16. + if x.dtype == np.float16: + return np.sinc(x.astype(np.float32)) + else: + return np.sinc(x) + +def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + return ( + SampleInput( + make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), + shape, + ) for size, shape in test_cases) + +def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs) + + m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + + cases = ( + ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)), + ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6)) + ) + + for a, b, c, d in cases: + yield SampleInput(m(a), args=(m(b), m(c), m(d))) + yield SampleInput(n(a), args=(n(b), n(c), n(d))) + +def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = ( + ((1, S), (2, S), (3, S),), + ((S, 1), (S, 2), (S, 3),), + ((1,), (2,), (3,),), + ((2, S), (S,)) + ) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + # We also want to test mixed complex-non-complex inputs to block_diag + if dtype == torch.complex32 or dtype == torch.complex64: + non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64 + make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs): + small_S = 2 + test_cases = ( + ((S, S, 2), (S, S + 1, 2)), + ((S, S), (S, S)), + ((S, S, S), (S, S, S)), + ((3, 5), (3, 5)), + ((2, 3, 5), (2, 3, 5)), + ((1, 2, 3), (1, 2, 3)), + ((1, 1), (S, 1)), + ((0, 5), (4, 5)), + ((4, 5), (0, 5)), + ((0, 4, 5), (3, 5)), + ((4, 5), (0, 3, 5)), + ((0, 4, 5), (1, 3, 5)), + ((1, 4, 5), (0, 3, 5)), + # Using S here would make this one test take 9s + ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)), + ((small_S, 1, 1, small_S), (1, small_S, small_S)), + ((1, 1, small_S), (small_S, 1, small_S, small_S)), + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + # FIXME add an override for JIT and revert 0. back to 0 + # since it's accepted by eager + for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]: + for t1_size, t2_size in test_cases: + # The args should never be non-contiguous as this is not supported in the backward + yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm) + +def _fill_np(a, value): + a = a.copy() + a.fill(value) + return a + +def _fill_sample_kwargs(device, dtype, input): + if dtype is torch.bool: + value = True + else: + value = 3 + + return ({'value': value}, {'value': value}) + +def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds a sample input where both tensors have the same values + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + lhs = make_arg((S, S)) + yield SampleInput(lhs, args=(lhs.clone(),)) + +def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape x number of tensors + cases = ( + ((3, 4), 1), + ((1, 2, 1, 4), 3), + ((0, 1, 0), 2),) + + for shape, num_tensors in cases: + tensors = [make_arg(shape) for _ in range(num_tensors)] + for dim in range(-1, len(shape) - 1): + yield SampleInput(tensors, args=(dim,)) + + +def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs): + # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. + # If all input tensors have the same ndims, we support both negative and non-negative dim. + # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. + # No requirements for (wrapped_dim, ...)-th dimension. + # 3. Expect positive num_chunks + # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element + # 5. Non-contiguous input tensors are allowed. + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + same_ndim_cases = ( + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 129]), + torch.Size([1, 2, 297]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], 1, 5 + ), + ( + [ + torch.Size([3, 3, 2, 1]), + torch.Size([1, 4, 2, 2]), + torch.Size([2, 1, 3, 3]), + ], 0, 2 + ), + ) + for sizes, dim, num_chunks in same_ndim_cases: + tensors = [make_arg(size) for size in sizes] + yield SampleInput(tensors, args=(dim, num_chunks)) + + different_ndim_case = [ + torch.Size([2, 3, 3]), + torch.Size([2, 3, 1, 2]), + torch.Size([2, 3]), + torch.Size([2, 3, 2]), + torch.Size([2, 3, 271]), + ] + max_dim, num_chunks = 2, 3 + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + tensors.append(make_arg(size)) + yield SampleInput(tensors, args=(dim, num_chunks)) + + # non-contiguous + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + # make the last 2 dims column-major (i.e. non-contiguous) + t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1) + tensors.append(t) + yield SampleInput(tensors, args=(dim, num_chunks)) + +def error_inputs_chunk_cat(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # input tensors have different ndims but dim is negative + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims', + ) + + # input tensors have different ndims but dim >= ndim of some input tensors + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects dim < ndim for all input tensors', + ) + + # some tensors have different sizes for 0, ..., dim-1 dimensions. + sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors', + ) + + # negative num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # zero as num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # empty input tensor list + dim, num_chunks = 0, 1 + yield ErrorInput( + SampleInput([], args=(dim, num_chunks)), + error_regex='_chunk_cat expects a non-empty input tensor list', + ) + + # empty input tensor with 0 elements + sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-empty tensor', + ) + + +def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple, tuple, dict] = ( # type: ignore[assignment] + ((S, S), (S, S), {'dim': -1}), + ((S, S), (S, S), {'dim': 1}), + ((M, S), (S, S), {'dim': 0}), # different shapes + ((1, 2, 3), (1, 2, 3), {'dim': -2}), + ((0,), (0,), {'dim': 0}), # empty tensor + ((0,), (S, S), {'dim': 1}), # empty tensor with unempty and dim=1 (special case for legacy_cat_wrap_dim) + ((0, S), (S, S), {'dim': 0}), + ((1,), (1,), {}) # dim not passed, fallback to default + ) + + for input_shape1, input_shape2, kwargs in cases: + yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs) + + # from coat_lite_mini + yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),) + +def error_inputs_cat(op_info, device, **kwargs): + + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for more than one element of the written-to tensor refer to a single memory location + yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))], + kwargs={'out': make_arg((1, S)).expand((2 * S, S))}), + error_regex='unsupported operation') + + # error inputs for empty tensors + yield ErrorInput(SampleInput([], kwargs={'dim': 1}), + error_regex='non-empty list of Tensors', error_type=ValueError) + + # error inputs for different sizes + yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + + # error inputs for different dimensions + yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + + # error inputs for same memory locations + x = torch.zeros((0), device=device) + y = torch.randn((4, 6), device=device) + + err_msg = "the written-to tensor refer to a single memory location" + + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}), + error_regex=err_msg) + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}), + error_regex=err_msg) + + z = torch.zeros((4, 6), device=device) + yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}), + error_regex=err_msg) + + # error inputs for different devices + if torch.device(device).type == 'cuda': + x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32) + y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32) + yield ErrorInput(SampleInput((x_cuda, y_cpu)), + error_regex='Expected all tensors to be on the same device') + + # error inputs for different input sizes for more than 2 tensors + yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]), + error_regex='Tensors must have same number of dimensions') + + yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))], + kwargs={'dim': 1}), + error_regex='Sizes of tensors must match') + + # error inputs for None input + yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError, + error_regex='got None') + + # error inputs for zero-dimensional tensors + yield ErrorInput(SampleInput([make_arg(()), make_arg(())]), + error_regex='zero-dimensional.*cannot be concatenated') + + # error inputs for different dtype of out tensors + d = make_tensor((2, 3), device=device, dtype=torch.double if not device.startswith("mps") else torch.float16) + x = make_tensor((2, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError, + error_regex='invalid combination of arguments') + +def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Noncontiguous type promoting tensors + a = make_arg((3, 4, 2)) + b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) + c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) + + yield SampleInput((a, b, c), kwargs={'dim': 1}) + + # Special 1D tensor with dim length of 0 case + a = make_arg((0,)) + b = make_arg((3, 2, 2)) + + yield SampleInput((a, b, a)) + yield SampleInput((a, a, a)) + +def _elementwise_type_promo_np(*args, type_promotion_kind): + def _maybe_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x) + return x + + flattened = pytree.arg_tree_leaves(*args) + transformed = tuple(_maybe_torch(a) for a in flattened) + result_dtype, _ = prims.utils.elementwise_dtypes( + *transformed, + type_promotion_kind=type_promotion_kind) + return torch_to_numpy_dtype_dict[result_dtype] + +def _cat_np(input_seq, dim=0): + inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0)) + + if len(inputs) == 0: + np_dtype = _elementwise_type_promo_np( + input_seq, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH) + return np.empty(0, dtype=np_dtype) + + return np.concatenate(inputs, axis=dim) + +def _floor_divide_np(a, b): + dtype = _elementwise_type_promo_np( + a, + b, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + if isinstance(a, np.ndarray): + a = a.astype(dtype) + if isinstance(b, np.ndarray): + b = b.astype(dtype) + return np.floor_divide(a, b) + +def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + tensor_shapes = ( + # First Tensor being 1-D is special + # case for hstack + ((S,), (S,), (S,)), + ((S, S), (S, S), (S, S)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + yield SampleInput(tensors) + +def error_inputs_hstack_dstack_vstack(op, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + tensor_shapes = ( + ((S,), (S, S, S, S), (S,)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + # Different dimension tensor + yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions") + + # empty tensor list + yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList") + +def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs): + # Note: we don't do any tests where we unbind along 0-length dims + # because in that case unbind returns and empty tuple, and that breaks + # some assumptions in some backward tests in test_ops.py + shape_dims = (((S,), 0), + ((S, S), 0), + ((S, S), 1), + ((S, S), -1), + ((S, 0, S), 0), + ((S, S, S), 1), + ) + for shape, dim in shape_dims: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + requires_grad=requires_grad), + args=(dim,)) + +def error_inputs_unbind(op_info, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError, + error_regex="Dimension specified as 0 but tensor has no dimensions") + yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError, + error_regex="Dimension out of range") + +def reference_unbind(t, dim): + """A numpy implementation of torch.unbind""" + return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim)) + +def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device)) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device).to(torch.int32)) + yield SampleInput( + make_arg((M, S)), + 1, + gather_variable((M, S // 2), 0, S, True, device=device)) + # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006 + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([], dtype=torch.uint8, device=device)) + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([[], []], dtype=torch.uint8, device=device)) + # 0D tensor case + yield SampleInput( + make_arg(()), + 0, + torch.tensor([0], dtype=torch.int64, device=device)) + yield SampleInput( + make_arg(()), + 0, + torch.tensor(0, dtype=torch.int64, device=device)) + +def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o): + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, idx.size(dim) + 1) + idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] + +def error_inputs_gather(op_info, device, **kwargs): + # src is [1, 2] + # [3, 4] + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + + # idx is [0, 0] + # [1, 0] + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + + # Index should be smaller than self except on dimension 1 + bad_src = make_tensor((1, 1), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), + error_regex="Size does not match at dimension 0") + + # TODO: FIXME + # out.dtype must match src.dtype + # Creates new src & idx since SampleInputs can't share tensors + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + out = torch.empty((2, 2), device=device, dtype=torch.float64) + yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), + error_regex="Expected out tensor to have dtype") + + # src and index tensors must have the same # of dimensions + # idx too few dimensions + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor((0, 0), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # src too few dimensions + src = torch.tensor((1, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(0, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx,)), + error_regex="index 23 is out of bounds for dimension") + + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_take(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +# Error inputs for scatter +def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): + # Error when self.dtype != src.dtype (and src is not a scalar) + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.double) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Expected self.dtype to be equal to src.dtype") + + # Index and destination must have the same number of dimensions + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as self tensor") + + # Index and src must have the same number of dimensions when src is not a scalar + src = make_tensor((2, 5, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as src tensor") + + # Index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="index 34 is out of bounds for dimension 0 with size 3") + +def error_inputs_renorm(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError, + error_regex="needs at least 2 dimensions, got 0 dimensions") + + +def error_inputs_ormqr(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError, + error_regex="input must have at least 2 dimensions") + + # https://github.com/pytorch/pytorch/issues/85218 + tensor_0 = torch.full((5, 0,), 1, device=device) + tensor_1 = torch.full((5,), 1, device=device) + tensor_2 = torch.full((5, 5,), 1, device=device) + bool_3 = True + bool_4 = True + yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError, + error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)") + + +def error_inputs_diag(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + zero_d = torch.randn(1, 1, 1, device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + +def error_inputs_embedding(op_info, device, **kwargs): + indices = torch.rand(2, 2, device=device).long() + weights = [ + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device).reshape(1, 1, 1), + ] + + for weight in weights: + yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError, + error_regex="'weight' must be 2-D") + + +def error_inputs_t(op_info, device, **kwargs): + yield ErrorInput( + SampleInput(torch.randn(2, 3, 4, 5, device=device)), + error_regex="expects a tensor with <= 2", + ) + + +def error_inputs_multinomial(op_info, device, **kwargs): + x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="prob_dist must be 1 or 2 dim") + + x = torch.empty(1, 2, dtype=torch.long, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="multinomial only supports floating-point dtypes for input") + + x = torch.empty(1, 2, dtype=torch.double, device=device) + y = torch.empty(1, 2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), + error_regex="multinomial expects Long tensor out") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(0,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(-1,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3, False,)), + error_regex="cannot sample n_sample > prob_dist") + + x = torch.empty(16777217, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3,)), + error_regex="number of categories cannot exceed") + + inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan)) + + err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0" + err_msg2 = "invalid multinomial distribution" + + rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) + + if torch.device(device).type == 'cpu': + for rep in rep_arg: + kwargs = {'num_samples': 2, 'replacement': rep} + + for shape in inputs: + # error case when input tensor contains `inf`, `nan` or negative element + yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), + error_regex=err_msg1 if rep is False else err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input + x = torch.zeros(3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input + x = torch.zeros(3, 3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution + x[1, :] = 1 + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + +def error_inputs_gradient(op_info, device, **kwargs): + for dtype in [torch.long, torch.float32, torch.complex64]: + t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype) + + dim = (1, 0) + spacing = [0.1] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected spacing to be unspecified, a scalar ') + + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)), + error_type=RuntimeError, + error_regex='torch.gradient only supports edge_order=1 and edge_order=2.') + + dim = (1, 1) + spacing = 0.1 + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='dim 1 appears multiple times in the list of dims') + + dim = (0, 1) + coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each tensor to be on the same device,') + + yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)), + error_type=IndexError, error_regex='') + + t = torch.tensor([[1], [2], [3]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + + t = torch.tensor([[1, 2], [3, 4]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + +def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True)) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S)) + yield SampleInput(make_arg(S), training=False) + +def error_inputs_rrelu(op_info, device, **kwargs): + input = make_tensor((S, S), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}), + error_regex='Lower bound should be less than or equal to the upper bound') + +def error_inputs_masked_select(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + y = torch.rand((6,), device=device) + mask = torch.tensor([True, False, True, True, False, False], device=device) + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_median(op_info, device, **kwargs): + x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan], + [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device) + if device == 'cuda': + yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))), + error_type=RuntimeError, + error_regex='CUDA Tensors cannot have more than 25 dimensions') + else: + return + + +def error_inputs_index_select(op_info, device, **kwargs): + x = torch.rand((1, 6), device=device).expand((2, 6)) + y = torch.rand((3, 6), device=device) + ind = torch.tensor([0, 1], dtype=torch.int64, device=device) + + yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_index_add(op_info, device, **kwargs): + result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]]) + source = torch.tensor([2., 4.]) + + yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)), + error_type=RuntimeError, + error_regex=r'source tensor shape must match self tensor shape, ' + r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]') + +def error_inputs_logcumsumexp(op_info, device, **kwargs): + dim = 3 + srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)] + for src in srcs: + yield ErrorInput(SampleInput(src, args=(dim,)), + error_type=IndexError, + error_regex='Dimension out of range') + +def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0) + + # `indices` broadcast + yield SampleInput( + make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1) + + # `self` broadcast + yield SampleInput( + make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1) + + # without `dim` arg + yield SampleInput( + make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device)) + + # Negative indices sample — guarded against python_ref + if not kwargs.get('is_python_ref', False): + neg_idx = gather_variable((S, S), 1, S, True, device=device) - S + yield SampleInput( + make_arg((S, S)), + neg_idx, + 1) + + +def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs): + + # Error Inputs for zero-dim tensors, when 'dim' arg is not provided. + shape = (S, 0, S) + err_msg_amax_amin = "reduction" + err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity" + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin) + elif op_info.name == 'aminmax': + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax) + + # Error Inputs for tensors with more than 64 dimension + sizes = [1] * 65 + err_msg1 = "only tensors with up to 64 dims are supported" + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}), + error_regex=err_msg1) + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}), + error_regex=err_msg1) + + # Error Inputs for repeated 'dim' + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + dims = [(0, 0), (0, -4)] + err_msg2 = "in the list of dims" + x = torch.randn(S, S, S, S, device=device) + for dim in dims: + yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2) + + # Error Input for illegal dtype + input5 = torch.randn(L, L, dtype=torch.float32, device=device) + max_values = torch.empty(L, dtype=torch.float32, device=device) + min_values = torch.empty(L, dtype=torch.double, device=device) + illegal_values = torch.empty(L, dtype=torch.int, device=device) + + # Unlike regular PyTorch, amax and amin refs don't require input and out + # dtypes to match exactly: + # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824 + if is_ref: + err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype " + "torch.int32, but this can't be cast because it is not safe!") + else: + err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float " + "for input's dtype and Int for out's dtype.") + err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead" + + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}), + error_regex=err_msg_amax_amin2) + elif op_info.name == 'aminmax': + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}), + error_regex=err_msg_aminmax2) + + # Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim + err_msg3 = "reduction" + # FIXME: eager and ref impl throw different types of errors + error_type = IndexError if 'refs' not in op_info.name else RuntimeError + yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}), + error_type=error_type, error_regex=err_msg3) + +def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): + test_cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S, S), {}), + ((S, S, S), {'dim': 1}), + ((S, S, S), {'dim': 1, 'keepdim': True}), + ((), {'dim': 0}), + ((), {}), + ((), {'dim': 0, 'keepdim': True}), + ((S, 0, S), {'dim': 0}), + ) + + for shape, kwargs in test_cases: + yield SampleInput( + make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad), + **kwargs) + +def error_inputs_diff(op_info, device, **kwargs): + t = torch.rand((1, 3), device=device) + n = -1 + yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs), + error_type=RuntimeError, + error_regex=f'order must be non-negative but got {n}') + +def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = ( + ((1,), 0, None, None), + ((S,), 0, None, None), + ((S, 1), 0, None, None), + ((S, 1), 1, None, None), + ((S, S), 0, None, None), + ((S, S), 1, None, None), + ((S, S), 0, (1, S), (2, S)), + ((S, S), 0, None, (2, S)), + ((XS, XS, XS), 1, None, None), + ((XS, XS, XS), 2, None, None), + ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)), + ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)), + ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),) + + for size, dim, size_prepend, size_append in test_cases: + prepend_size = 0 if (size_prepend is None) else size_prepend[dim] + append_size = 0 if (size_append is None) else size_append[dim] + dim_size = size[dim] + prepend_size + append_size + for n in range(dim_size): + input_tensor = make_arg(size) + prepend = make_arg(size_prepend) if size_prepend else None + append = make_arg(size_append) if size_append else None + yield SampleInput(input_tensor, n, dim, prepend, append) + + # add some samples with n > dim_size + yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1) + yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS))) + +def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]): + input_tensor = make_arg(size) + weight_tensor = make_arg(size) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = make_arg((bin_ct + 1,)) + sorted_bins, _bins_indices = torch.sort(bins_tensor) + yield SampleInput(input_tensor, sorted_bins, + weight=weight_tensor, density=density) + +def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S)) + bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3)) + + for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]): + input_tensor = make_arg(size) + bin_ct = bin_ct_pattern[:size[-1]] + weight_tensor = make_arg(size[:-1]) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = [make_arg(ct + 1) for ct in bin_ct] + yield SampleInput(input_tensor, bins_tensor, + weight=weight_tensor, density=density) + +def error_inputs_histogramdd(opinfo, device, **kwargs): + invalid_bins = [1, 1, 1, 1, 1] + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input." + yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg) + +def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, min, max in product(sizes, [0, -10], [0, 10]): + # construct sample input omitting bins arg + yield SampleInput(make_arg(size), min=min, max=max) + + # construct sample inputs with a few different bins values + for bins in [1, 3, 10]: + yield SampleInput(make_arg(size), bins=bins, min=min, max=max) + +def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for size, weighted in product((S, M), [False, True]): + input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device) + weight_tensor = make_arg((size,)) if weighted else None + + max_val = int(input_tensor.max().item()) + + for minlength in [0, max_val // 2, max_val, 2 * max_val]: + yield SampleInput( + input_tensor, weights=weight_tensor, minlength=minlength) + +def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S)) + + if reference_inputs_mode: + sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33)) + + for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]): + input_tensor = make_arg(input_shape) + boundaries = make_arg(nb).msort() + + yield SampleInput(input_tensor, boundaries, + out_int32=out_int32, right=right) + +reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True) + +def error_inputs_bucketize(opinfo, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))), + error_regex="boundaries tensor must be 1 dimension") + +def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # (unsorted tensor size, (input sizes,), is_scalar) + sizes = ( + ((0,), ((0,),), False), + ((M,), ((), (M,), (M, M)), False), + ((0, 0), ((0, 0),), False), + ((M, M), ((M, M),), False), + ((0, 0, 0), ((0, 0, 0),), False), + ((M, M, M), ((M, M, M),), False), + ((L,), ((),), True), + ) + + for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product( + sizes, [False, True], [False, True], [False, True] + ): + unsorted_tensor = make_arg(size, noncontiguous=noncontiguous) + for input_size in input_sizes: + input = make_arg(input_size, noncontiguous=noncontiguous) + if is_scalar: + input = input.item() + if np.prod(size) == 0: + boundary_tensor = unsorted_tensor + sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous) + else: + boundary_tensor, sorter = torch.sort(unsorted_tensor) + side = "right" if right else "left" + + yield SampleInput(boundary_tensor, input, out_int32=out_int32, right=right) + yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side) + + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter) + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter) + +def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + test_cases_float = ( + ((S,), None, None, 1), + ((S,), 2., None, 1), + ((S, S), None, None, 2), + ((S, S), [2.0, 2.1], None, 1), + ((S, S), [2.0, 2.1], (0, 1), 1), + ((4, 4, 4), [2., 1.], (0, 1), 2), + ) + for size, spacing, dim, edge_order in test_cases_float: + t = make_arg(size) + yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order) + + test_cases_tensor = ( + ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1), + ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2), + ) + for size, coordinates, dim, edge_order in test_cases_tensor: + t = make_arg(size) + coordinates_tensor_list = [] + for coords in coordinates: + # `coords` will always contain floating point values and Python 3.10 does not support this + # implicit conversion to an integer using `__int__` + # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed + a = torch.tensor(coords, device=device) + coordinates_tensor_list.append(a.to(dtype)) + yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order) + +def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_args = [ + ([1, 2],), + (slice(0, 3),), + ((slice(0, 3), 1),), + (([0, 2, 3], [1, 3, 3], [0, 0, 2]),), + (([0, 0, 3], [1, 1, 3], [0, 0, 2]),), + ((slice(None), slice(None), [0, 3]),), + ((slice(None), [0, 3], slice(None)),), + (([0, 3], slice(None), slice(None)),), + (([0, 3], [1, 2], slice(None)),), + (([0, 3], ),), + (([0, 3], slice(None)),), + (([0, 3], Ellipsis),), + (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),), + (index_variable(2, S, device=device),), + (mask_not_all_zeros((S,)),), + ] + + for args in test_args: + yield SampleInput(make_arg((S, S, S)), args=args) + + yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),)) + +def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for accumulate in [False, True]: + # Test with indices arg + yield SampleInput( + make_arg((S, S,)), + # As defined in the docs, if accumulate is false, duplicate indices are not supported + (index_variable(2 if accumulate else 1, S, device=device),), + make_arg((2 if accumulate else 1, S)), + accumulate=accumulate) + + # Test with mask arg + mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,)) + yield SampleInput( + make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate) + +def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs): + def small_3d_unique(): + res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + def large_1d_unique(): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique()) + + # Test cases for small 3d tensors. + # Imitates legacy tests from test/test_torch.py + dims = range(-3, 3) + flag = [True, False] + for dim, descending, stable in product(dims, flag, flag): + # default schema without stable sort + if not (dtype == torch.bool and torch.device(device).type == 'cuda'): + # bool and cuda requires stable sort for stable results, at least + # for the return index + yield SampleInput(small_3d_unique(), dim, descending) + # schema with stable sort, no CUDA support yet + if torch.device(device).type == 'cpu': + yield SampleInput( + small_3d_unique(), dim=dim, descending=descending, stable=stable) + + # Test cases for scalar tensor + tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(torch.tensor(1, **tensor_opt)) + yield SampleInput(torch.tensor(1, **tensor_opt), 0) + yield SampleInput(torch.tensor(1, **tensor_opt), 0, True) + + # Test cases for empty tensor + yield SampleInput(torch.tensor((), **tensor_opt)) + yield SampleInput(torch.tensor((), **tensor_opt), 0) + yield SampleInput(torch.tensor((), **tensor_opt), 0, True) + + # Test cases for stable sort + yield SampleInput(small_3d_unique(), stable=True) + yield SampleInput(small_3d_unique(), dim=0, stable=True) + yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True) + +def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S)) + for x_size in sizes: + # threshold and values args must be numbers + yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item()) + +def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for shape, sorted, return_inverse, return_counts, dim in \ + product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]): + # torch.unique cannot be called if the input tensor has a zero dimension which isn't the selected dim + if 0 in shape and shape.index(0) is not dim: + continue + + # skip invalid dim args + if dim is not None and (dim < -len(shape) or dim >= len(shape)): + continue + + kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + + # construct a test case with only one distinct value + input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with mixed 0s and 1s + input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\ + .to(dtype).requires_grad_(requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with many different values + yield SampleInput(make_arg(shape), **kwargs) + +def sample_inputs_unique_consecutive(*args, **kwargs): + for sample_input in sample_inputs_unique(*args, **kwargs): + if not sample_input.kwargs["sorted"]: + sample_input.kwargs.pop("sorted") + yield sample_input + +def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8), (5,)), + ((3, 8, 8), 5), + ((3, 8, 8), 1) + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((1, 8, 8, 8), (5, 7)), + ((2, 8, 8, 8), (None, 7)), + ((1, 8, 4, 3), (5, None)), + ((1, 8, 4, 3), (None, None)), + ((1, 8, 4, 3), (5)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 2") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8, 8, 8), (5, 7, 4)), + ((1, 8, 4, 3, 7), (None, None, None)), + ((1, 8, 4, 3, 7), (1, 1, 1)), + ((3, 3, 8, 8, 6), (5, 7, None)), + ((1, 3, 8, 8, 6), (5, None, 2)), + ((3, 3, 8, 8, 6), (None, 3, 2)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 3") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8), (5,)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((3, 4, 4), 3), + ((3, 4, 4), 1) + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + + +def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="Trying to create tensor with negative dimension") + +def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8), (5, 7)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 4), (2, 3)), + ((2, 4, 4, 4), (None, 3)), + ((2, 4, 4, 4), (1, 1)), + ((1, 4, 4, 3), (3, None)), + ((1, 4, 4, 3), (None, None)), + ((1, 4, 4, 3), (3)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="Trying to create tensor with negative dimension") + + +def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8, 8), (5, 7, 4)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 3, 5), (None, None, None)), + ((1, 4, 4, 3, 5), (1, 1, 1)), + ((3, 3, 4, 4, 6), (2, 3, None)), + ((1, 3, 4, 4, 6), (3, None, 2)), + ((3, 3, 4, 4, 6), (None, 3, 2)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="Trying to create tensor with negative dimension") + + +class _TestParamsMaxPoolBase: + + def __init__(self) -> None: + self.kwargs = { + 'kernel_size': [3], + 'stride': [2, None], + 'ceil_mode': [True, False], + 'padding': [0, 1], + 'dilation': [1], + 'return_indices': [True, False] + } + + self.shapes = [ + [1, 2, None], # batch + [2], # channels + [3, 6] # signal + ] + + def _gen_shape(self): + for shape in product(*self.shapes): + # shape[0] is None indicates missing batch dimension + if shape[0] is None: + shape = shape[1:] + + yield shape, torch.contiguous_format + # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format + if len(self.shapes) == 4 and len(shape) == 4: + yield shape, torch.channels_last + + def _gen_kwargs(self): + keys = self.kwargs.keys() + for values in product(*self.kwargs.values()): + yield dict(zip(keys, values, strict=True)) + + def gen_input_params(self): + yield from product(self._gen_shape(), self._gen_kwargs()) + +class _TestParamsMaxPool1d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3,)] + self.kwargs['stride'] += [(2,)] + self.kwargs['padding'] += [(1,)] + self.kwargs['dilation'] += [(1,)] + +class _TestParamsMaxPool2d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2)] + self.kwargs['stride'] += [(2, 1)] + self.kwargs['padding'] += [(1, 1)] + self.kwargs['dilation'] += [(1, 2)] + + self.shapes.append([6]) + +class _TestParamsMaxPool3d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2, 3)] + self.kwargs['stride'] += [(2, 1, 2)] + self.kwargs['dilation'] += [(1, 2, 1)] + + self.shapes.append([6]) + self.shapes.append([5]) + +def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + params_generator_type_dict = { + 'nn.functional.max_pool1d': _TestParamsMaxPool1d, + 'nn.functional.max_pool2d': _TestParamsMaxPool2d, + 'nn.functional.max_pool3d': _TestParamsMaxPool3d, + 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d, + } + + params_generator = params_generator_type_dict[op_info.name]() + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield SampleInput(arg, kwargs=kwargs) + +def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs): + out, indices = torch.nn.functional.max_pool2d_with_indices( + *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True) + grad_out = torch.ones_like(out) + if stride is None: + stride = kernel_size + out_b = torch.ops.aten.max_pool2d_with_indices_backward.default( + grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices) + return out_b + +def error_inputs_max_pool1d(op_info, device, **kwargs): + # Toggle requires_grad because `max_pool1d` has different path + # based on whether `requires_grad` is set or not. + for requires_grad in (True, False): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default + yield ErrorInput(SampleInput(x, + kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for input tensor + error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input + yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: unbatched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input with stride=0 + error_msg = 'stride must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}), + error_regex=error_msg) + + # error inputs for empty input with dilation=0 + error_msg = 'dilation must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), + kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}), + error_regex=error_msg) + + # error inputs for invalid output size + error_msg = 'Invalid computed output size: -2' + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), + kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}), + error_regex=error_msg) + + # error inputs when kernel_size=0 + error_msg = 'kernel_size must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}), + error_regex=error_msg) + + # error inputs for strides > 0 + error_msg = 'stride must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}), + error_regex=error_msg) + + +def error_inputs_max_pool2d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size : int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size : tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((1, 0, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + + +def error_inputs_max_pool3d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49, 50)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size: int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size: tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected input\'s non-batch dimensions to have positive length' + yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched inputs with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 4, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + + + +def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple[int, ...], dict] = ( + ((2, 1, 4, 5), {'p': 1., 'dim': 2}), + ((2, 3, 4, 5), {'p': 2., 'dim': 1}), + ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}), + ((1, 3, 4, 5), {'p': -1., 'dim': 1}), + ((1, 3, 4, 5), {'p': 0., 'dim': -1}), + ((), {'p': 1.2, 'dim': 0}), + ((2, 3, 4, 5), {}), + ((2, 3, 4, 5), {'eps': 1e-4})) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), kwargs=kwargs) + + +def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups): + # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0)) + # a = conv(Wr, xr, br), + # b = conv(Wi, xi, 0), + # c = conv(Wr + Wi, xr + xi, br + bi) + # conv(W, x, b) = a - b + i(c - a - b) + + grad_output_ = torch.view_as_real(grad_output) + grad_output_r = grad_output_[..., 0] + grad_output_i = grad_output_[..., 1] + + weight_ = torch.view_as_real(weight) + weight_r = weight_[..., 0] + weight_i = weight_[..., 1] + + a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups) + b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups) + c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups) + + return (a - b) + 1j * (c - a - b) + + +def conv_transpose_ref(input, weight, bias, stride=1, padding=0, + output_padding=0, dilation=1, groups=1, + fn=None): + # Derivative of `conv` is `conv_transpose`. + # To verify the correctness of `conv_transpose`, + # we rely `torch.nn.grad` implementation (which is tested in test_nn.py) + # for floating dtypes. + + assert fn is not None + + grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input, + torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input, + torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input} + batched_dim_map = {torch.nn.functional.conv_transpose1d: 3, + torch.nn.functional.conv_transpose2d: 4, + torch.nn.functional.conv_transpose3d: 5} + + # Input for `ref` is ndarray. + input, weight = torch.from_numpy(input), torch.from_numpy(weight) + + is_batched = len(input.shape) == batched_dim_map[fn] + if not is_batched: + input = input.unsqueeze(0) + + if bias is not None: + bias = torch.from_numpy(bias) + unsqueeze_dims = input.ndim - 2 + for _ in range(unsqueeze_dims): + bias = bias.unsqueeze(1) + + grad_output = input + # Get the input shape for grad_fn. + conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None, + stride=stride, padding=padding, output_padding=output_padding, + groups=groups, dilation=dilation) + input_size = conv_transpose_output.shape + + grad_fn = grad_fn_map[fn] + if weight.dtype.is_complex: + out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups) + else: # Floating + out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups) + + if bias is not None: + out = out + bias + + return out.squeeze(0) if not is_batched else out + + +def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4), (3, 3, 3), (3,), + {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}), + ((2, 2, 4), (2, 2, 4), (4,), + {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}), + ((1, 1, 4), (1, 1, 4), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}), + ((1, 1, 4), (1, 2, 3), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5), (4, 8, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}), + ((2, 2, 4, 4), (2, 2, 4, 5), (4,), + {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}), + ((1, 1, 4, 5), (1, 1, 4, 3), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 1, 4, 3), (1, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}), + ((1, 2, 5, 5), (2, 4, 3, 3), None, {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + +def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,), + {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}), + ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,), + {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}), + ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}), + ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias, + # and a dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}), + ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}), + ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}), + ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}), + # With defaults + ((1, 4, 5), (3, 4, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), + error_regex="weight should have at least three dimensions") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def error_inputs_conv2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for groups the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, groups, dilation) + cases: tuple = ( + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'groups': 1}), + ((2, 4, 8, 8), (2, 2, 3, 3), (2,), + {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 2, 4, 3), (4, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'groups': 1}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': "valid"}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 1, 'padding': "same", 'dilation': 3}), + # Below are the group related samples from common_nn.py + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}), + ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}), + ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}), + # With defaults + ((1, 4, 5, 5), (3, 4, 3, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), + ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), + ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), + ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), + error_regex="non-positive groups is not supported") + + # error inputs for padding='same' not supported by strided convolutions + yield ErrorInput( + SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), + make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), + error_regex="padding='same' is not supported for strided convolutions") + + +def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int, ...], int, float] = ( + ((1, 6, 3), 2, {'eps' : 0.5}), + ((2, 6, 3), 2, {'eps' : -0.5}), + ((1, 3), 1, {'eps' : 1e-5}), + ((0, 2), 1, {'eps' : 1e-5}), + ((S, S, S), 1, {'eps' : 0.5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(make_arg(input_shape), num_groups, **kwargs) + + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=(1,)) + +def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_group_norm( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int, ...], int, float] = ( + ((20, 6, 10, 10), 3, {'eps' : 1e-5}), + # equivalent with InstanceNorm + # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) + ((20, 6, 10, 10), 6, {'eps' : 1e-5}), + # equivalent with LayerNorm + # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) + ((20, 6, 10, 10), 1, {'eps' : 1e-5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + input_tensor = make_arg(input_shape) + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(input_tensor, num_groups, **kwargs) + + +def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for momentum, eps + cases: tuple[tuple[int, ...], dict] = ( + ((S, S, S), {'momentum': 0.5, 'eps': 0.6}), + ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}), + ((3, 2, 4), {'momentum': -1.2}), + ((3, 2, 4), {'momentum': 0.0}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] + weight = make_arg(channels) + bias = make_arg(channels) + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + new_kwargs = { + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': weight, + 'bias': bias, + **kwargs + } + + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs=new_kwargs + ) + + # Checking for permutations of weights and biases as `None` + # instance_norm assumes that if there's a bias, there's a weight + weights = [channels, None] + biases = [None, None] + + for weight_channels, bias_channels in zip(weights, biases, strict=True): + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs={ + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': make_arg(weight_channels) if weight_channels is not None else None, + 'bias': make_arg(bias_channels) if bias_channels is not None else None + } + ) + + # Test case for no optional kwargs + yield SampleInput(make_arg((1, 2, 3)), kwargs={}) + +def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + def make_bool_mask(*shape): + return torch.randint(0, 2, shape, device=device, dtype=torch.bool) + + def mask_two_rows(rows, cols): + mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device) + mask_two_rows[rows - 1] = False + mask_two_rows[rows - 3] = False + return mask_two_rows + + def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor: + return torch.where(~mask, float('-inf'), 0.0) + + def with_requires_grad(tensor): + return tensor.requires_grad_(requires_grad) + + def generate_input_from_mask(mask_shape, dim): + mask = make_bool_mask(*mask_shape) + input_tensor = make_arg(mask_shape) + masked_input = input_tensor + convert_to_float_mask(mask) + return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim}) + + samples = [ + # Basic 3D tensor with mask + generate_input_from_mask((2, 3, 4), dim=1), + # 2D tensor with mask, testing different dim + generate_input_from_mask((5, 5), dim=0), + # 4D tensor, testing with a different dim + generate_input_from_mask((2, 3, 4, 5), dim=2), + # Edge case: 1D tensor + generate_input_from_mask((10,), dim=0), + # Edge case: tensor with one dimension of size 1 + generate_input_from_mask((1, 5, 5), dim=1), + # Testing with all elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.zeros((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with no elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.ones((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with two rows masked + SampleInput( + with_requires_grad( + make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3)) + ), + kwargs={"dim": 1}, + ), + ] + yield from samples + +def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + + # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs, + # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400 + + # With weight and a `None` bias + # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None)) + + # With `None` weight and bias (tests failing for this, see the link above) + # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,)))) + + +def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape, eps + cases: tuple[tuple[int, ...], tuple[int, ...], float] = ( + ((1, 2, 3), (1, 2, 3), 0.5), + ((2, 2, 3), (2, 3), -0.5), + ((1,), (1,), 1e-5), + ((1, 2), (2,), 1e-5), + ((0, 1), (1,), 1e-5), + ) + + for input_shape, normalized_shape, eps in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, None, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, None, eps), + ) + +def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + +def error_inputs_group_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # check that input has minimum number of dimensions + err_msg1 = "Expected at least 2 dimensions for input tensor but received" + s1 = SampleInput(make_arg(1), args=(1,)) + yield ErrorInput(s1, error_regex=err_msg1) + + # check that the channels dimension is compatible with number of groups + err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" + s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) + yield ErrorInput(s2, error_regex=err_msg2) + +def error_inputs_native_layer_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + bias = make_arg((1, 2)) + err_msg3 = "Expected bias to be of same shape as normalized_shape" + s3 = SampleInput( + make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5) + ) + yield ErrorInput(s3, error_regex=err_msg3) + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + +def error_inputs_rms_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + + +def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, size and a kwarg dict for alpha, beta, and k + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}), + ((1, 6, 3), 2, {'alpha': 3e-05}), + ((1, 6, 3), 2, {'beta': 0.5}), + ((1, 6, 3), 2, {'k': 1.25}), + ((1, 6, 3), 2, {}), + ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ) + + for input_shape, size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs) + +def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs): + N = 5 + # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ? + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-5, high=5) + return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N)) + +def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4], [8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor = create_tensor(batch_shape + [in_feat]) + weight = create_tensor([out_feat, in_feat]) + if not has_bias: + yield SampleInput(input_tensor, weight) + continue + + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor, weight, bias) + + # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942 + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2)) + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4)) + +def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4, 5], [8, 8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor1 = create_tensor(batch_shape + [in_feat1]) + input_tensor2 = create_tensor(batch_shape + [in_feat2]) + weight = create_tensor([out_feat, in_feat1, in_feat2]) + if not has_bias: + yield SampleInput(input_tensor1, input_tensor2, weight) + continue + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor1, input_tensor2, weight, bias) + +def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs): + features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for features, batch_shape in itertools.product(features_options, batch_options): + ndim = len(features) + len(batch_shape) + for dim in range(ndim): + input_tensor = create_tensor(batch_shape + features) + dim_size = input_tensor.size(dim) + if dim_size > 0 and dim_size % 2 == 0: + yield SampleInput(input_tensor, dim) + +def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + align_corners_options: tuple[Any, ...] = (None,) + if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): + align_corners_options = (True, False, None) + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'nearest-exact': [1, 2, 3], + 'linear': [1], + 'bilinear': [2], + 'bicubic': [2], + 'trilinear': [3], + 'area': [1, 2, 3] + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + def uneven_shape(size, rank, with_batch_channel=True): + rc = list(shape(size, rank, with_batch_channel)) + rc[-1] += 1 + if rank > 2: + rc[-2] -= 1 + return tuple(rc) + + if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for a more close to typical image processing usage + rank = 2 + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg(shape(270, rank), memory_format=memory_format), + shape(130, rank, False), + scale_factor=None, + mode=mode, + align_corners=False, + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for align_corners in align_corners_options: + for rank in ranks_for_mode[mode]: + yield SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + if rank > 1 and dtype.is_floating_point: + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + for recompute_scale_factor in [False, True]: + for scale_factor in [1.7, 0.6]: + yield SampleInput( + make_arg(shape(D, rank)), + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + +def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs) + + if mode in ('bilinear', 'bicubic'): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + for aa in [True, False]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + scale_factor=None, + mode=mode, + align_corners=False, + antialias=aa, + ) + +def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'bilinear': [2], + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return torch.Size([N, C] + ([size] * rank)) + return torch.Size([size] * rank) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for rank in ranks_for_mode[mode]: + yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6) + +def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) + + if mode == 'bilinear': + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide a single sample for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + ) + +def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs): + N = 6 + C = 3 + H = 10 + W = 20 + S = 3 + L = 5 + + input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9]) + yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0]) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9) + +def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs): + N = 5 + for _ in range(1, N): + for approximate in ['none', 'tanh']: + yield SampleInput( + make_tensor((N * 2, N * 2), device=device, dtype=dtype, + requires_grad=requires_grad, low=-3, high=3), + approximate=approximate) + + +def error_inputs_gelu(op, device, **kwargs): + # Tests that gelu errors out when passed an approximation we don't know. + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}), + error_regex="approximate argument must be either") + + +def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs): + args_for_reduction_with_dim = ( + ((S, S, S), (1,),), + ((S, S, S), (1, True, ),), + ((), (0,),), + ((), (0, True,),), + ) + return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *args)) + for input_tensor, args in args_for_reduction_with_dim) + +def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + +def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs): + yield from _generate_reduction_inputs(device, dtype, requires_grad) + # NaN only exists for floating point numbers + if dtype.is_complex or dtype.is_floating_point: + yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) + yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_nan_reduction(supports_multiple_dims): + # Generates sample inputs for reduction ops that contain the input tensor + # and dim and keepdim kwargs. If a reduction op needs to test additional + # args/kwargs then create a separate sample_inputs function + def fn(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_nan_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + yield SampleInput(t.clone().requires_grad_(requires_grad)) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): + yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs) + + return fn + +def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs): + test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad)) + test_interpolations = ['linear', 'midpoint'] + + for quantiles in test_quantiles: + for t in _generate_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False): + # Interpolation kwarg for now is only supported when providing both dim and keepdim + kwargs.setdefault('dim', 0) + kwargs.setdefault('keepdim', False) + for interpolation in test_interpolations: + kwargs['interpolation'] = interpolation + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles, **kwargs) + +def sample_inputs_reduction_count_nonzero(*args, **kwargs): + """Sample inputs for count_nonzero""" + # count_nonzero does not support keepdim yet + for sample in sample_inputs_reduction(*args, **kwargs): + sample.kwargs.pop('keepdim', None) + yield sample + +def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs): + N = 10 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + return (SampleInput(make_arg((N, N))) for _ in range(1, N)) + +def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((1, 3, 9, 9), 3), + ((1, 3, 9, 9), (4, 4)), + ((1, 3, 9, 9), (6, 6)), + ((2, 3, 9, 9), (3, 3)), + ((1, 1, 4, 4), (2, 2)), + ((1, 2, 6, 6), (4, 4))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5), + return_indices=return_indices, + ) + + yield SampleInput( + make_arg((1, 1, 16, 16)), + (1, 1), + output_ratio=(0.5, 0.5), + return_indices=True, + _random_samples=make_tensor((1, 1, 2), device=device, dtype=dtype, requires_grad=False), + ) + +def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((2, 3, 5, 5, 5), (2, 2, 2)), + ((1, 2, 6, 5, 4), 2), + ((1, 2, 5, 6, 5), (2, 3, 2)), + ((1, 2, 6, 6, 6), (2, 3, 2)), + ((1, 1, 7, 6, 7), (2, 3, 4)), + ((1, 1, 4, 5, 4), (2, 2, 1)), + ((1, 1, 8, 7, 6), (4, 3, 2)), + ((0, 1, 4, 5, 4), (2, 2, 1))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3, 2), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5, 0.5), + return_indices=return_indices, + ) + +def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2), + ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2), + ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2), + ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2), + ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2), + ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None)) + + for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases: + yield SampleInput(make_arg(input_shape), + args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)) + # Case with just input_shape and kernel_size + yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3))) + +def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, kwargs + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 9), (3,), {}), + ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)), + ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)), + ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)), + ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 3, 4, 4), (2, 2, 2), {}), + ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True, + count_include_pad=False, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True, + count_include_pad=True, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)), + ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False, + count_include_pad=False, divisor_override=2)), + ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=-2)), + ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True, + count_include_pad=True, divisor_override=None)), + ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def error_inputs_avg_pool1d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + +def error_inputs_avg_pool2d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + +def error_inputs_avg_pool3d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49, 50], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + + # error inputs for invalid input dimension + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}), + error_regex='non-empty 4D or 5D') + + +def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # test_multiple_devices_to_cuda would fail if we use a different device than given + devices = [device] + if torch.device(device).type == 'cpu': + devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices + memory_formats = [torch.preserve_format, torch.channels_last] + + # TODO: can't switch `to.device` overload to use positional arguments + # https://github.com/pytorch/pytorch/issues/84265 + # to.device overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs) + + # to.dtype overload + for nb, cp, mem_f in product([True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs) + + # to.other overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + other = make_arg((S, S, S, S), dtype=torch.float64, device=device) + yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs) + + +def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs): + def get_tensor_input(size): + return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(get_tensor_input((S, M, S)), 3) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True) + + yield SampleInput(get_tensor_input(()), 1) + yield SampleInput(get_tensor_input(()), 1, 0) + yield SampleInput(get_tensor_input(()), 1, -1) + yield SampleInput(get_tensor_input(()), 1, 0, True) + yield SampleInput(get_tensor_input(()), 1, -1, True) + yield SampleInput(get_tensor_input(()), 1, 0, True, True) + yield SampleInput(get_tensor_input(()), 1, -1, True, True) + +def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(M)) + +def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S)) + ps = (2, 4) + + for size_x, size_y, p in product(sizes, sizes, ps): + yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p)) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs): + # target.index_add(dim, idx, source, *, alpha=1) + add = "index_add" in op_info.name + # target.index_copy(dim, idx, source) + copy = "index_copy" in op_info.name + # target.index_fill(dim, idx, value) + fill = "index_fill" in op_info.name + + # Extended reference inputs. We generate that exercise atomic adds / writing + # several times to one location + if reference: + make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad) + make_idx = partial(torch.zeros, device=device, dtype=torch.int64) + else: + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # idx They need to be different for copy and add to be deterministic + if copy or add: + make_idx = partial(torch.randperm, device=device, dtype=torch.int64) + else: + def make_idx(n): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n) + + shapes = [(), (1,), (S, S)] + # extra parameter for add + if add: + if dtype == torch.bool: + alphas = (True, False) + else: + alphas = (-1, 0, 2) + else: + alphas = (None,) + + if fill: + # A weird number to catch errors. + # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`. + values = (make_arg((1,)).item(), make_arg(())) + else: + values = (None,) + + for shape, alpha, value in product(shapes, alphas, values): + t = make_arg(shape) + args = [] + + # dim. We handle the scalar case + dim = -1 if t.ndim == 2 else 0 + args.append(dim) + + idx = make_idx(t.shape[dim] if t.ndim != 0 else 1) + args.append(idx) + + # source + if copy or add: + args.append(make_arg(shape)) + elif fill: + args.append(value) + + args = tuple(args) + kwargs = {} if alpha is None else {"alpha": alpha} + + yield SampleInput(t, args=args, kwargs=kwargs) + +def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m) + + shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))] + include_selfs = (True, False) + reduce = op_info.variant_test_name + assert reduce in ('prod', 'mean', 'amin', 'amax') + + for shape, include_self in product(shapes, include_selfs): + self_shape, src_shape = shape + # dim. We handle the scalar case + dim = 1 if len(self_shape) >= 2 else 0 + idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1, + self_shape[dim] if len(self_shape) != 0 else 1) + args = (dim, idx, make_arg(src_shape), reduce) + yield SampleInput(make_arg(self_shape), + args=args, + kwargs={'include_self' : include_self}) + + # Sample inputs to test edge cases for backward + if requires_grad and reduce == 'prod': + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1]) + # (c) no zeros reduced (self[2, 1], self[2, 2]) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(0, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, M), + ((S, S), M, S), + ((S, S, S), S, M), + ] + + fill_value = make_tensor([], dtype=dtype, device="cpu").item() + + for c in cases: + self_shape, high, idx_size = c + dim = len(self_shape) + indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + +def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, (M, M)), + ((S, S), M, (S, S + 1)), + ((S, S, S), S, (M, M - 1, M + 1)), + ] + + for c in cases: + self_shape, high, idx_sizes = c + dim = len(self_shape) + indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + values = make_arg(idx_sizes) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + +def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): + args = ( + ((S, S, S), (),), + ((S, S, S), (1, ),), + ((S, S, S), (1, True, ),), + ((), (),), + ((), (0,),), + ((), (0, True,),), + # Non-fused mode kernel on CUDA + ((3000,), ()), + ) + make_arg = partial(make_tensor, dtype=dtype, device=device, + requires_grad=requires_grad, low=None, high=None) + return (SampleInput(make_arg(input_tensor), *args) + for input_tensor, args in args) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs + idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S] + idx_list = [idx, -idx - 1] + for idx, acc in product(idx_list, (True, False)): + yield SampleInput(input=make_arg((S, S)), + args=(idx.clone(), + make_arg((S,)), + acc)) + + # Scalar cases + scalar_sizes = [(), (1,)] + tgt_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + src_gen = (make_arg(size) for size in scalar_sizes) + for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + + # Empty cases + tgt_sizes = [(0,), (), (1,), (3, 2)] + tgt_gen = (make_arg(size) for size in tgt_sizes) + idx = make_idx((0,), high=1) + src = make_arg((0,)) + for tgt, acc in product(tgt_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + +def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs: take S elements out of S * S + index = make_idx((S,), high=(S * S)) + for idx in (index, -index - 1): + yield SampleInput(input=make_arg((S, S)), args=(idx,)) + + # Scalar cases + scalar_sizes = [(), (1,)] + src_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + for src, idx in product(src_gen, idx_gen): + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + + # Empty cases + src_sizes = [(0,), (), (1,), (3, 2)] + src_gen = (make_arg(size) for size in src_sizes) + + idx = make_idx((0,), high=1) + for src in src_gen: + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + +def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0]) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0]) + +def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape, source, destination + args = ( + # empty inputs + ((), (), ()), + # int inputs, negative + ((3, 5, 7, 2), -2, 1), + # swap bounds + ((3, 5, 7, 2), (-1, 0), (0, -1)), + # non-sequential, negative + ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)), + # idempotence, negative + ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)), + # reverse, sequential, positive + ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)), + # reverse, non-sequential + ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)), + # reverse, sequential, negative + ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)), + ) + + for shape, source, destination in args: + yield SampleInput(make_arg(shape), args=(source, destination)) + +def error_movedim_moveaxis(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # source length < destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0, -1\] dims\)"), + ) + + # source length > destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3, 4\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0\] dims\)"), + ) + + # repeated source dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))), + error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)", + ) + + # repeated destination dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)", + ) + + # repeated dim (both), with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)", + ) + + # out of bounds source inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds source input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + +def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),) + shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1)) + + if requires_grad: + # Tests for variant_consistency_jit, grad, gradgrad + # are slower. Use smaller bags of `rep_dims` and `shapes` + # in this case. + rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1)) # type: ignore[assignment] + shapes = ((), (0,), (2,), (3, 2)) # type: ignore[assignment] + + is_repeat_op = op_info.name in ['repeat', '_refs.repeat'] + for rep_dim, shape in product(rep_dims, shapes): + # `torch.repeat` errors for `len(rep_dims) < t.dim()`, + # so we filter such combinations. + if is_repeat_op and len(rep_dim) < len(shape): + continue + yield SampleInput(make_arg(shape), rep_dim) + + +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + shapes_and_args = ( + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) + + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. + # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + +def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=None, high=None) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_axes = [ + ((3, 4, 5), 0), + ((3, 4, 5), 1), + ((3, 4, 5), 3), + ((3, 4, 5), -1), + ((3, 4, 5), -3), + ((), 0), + ((), -1), + ((1,), 0), + ((1,), -1), + ] + + for shape, axis in shapes_and_axes: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, axis) + + +def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((0, 1, 5, 5), (2, 3, 5, 5)) + kernel_sizes = (2, (2, 2), (2, 3)) + dilations = (1, 2, (1, 2)) + paddings = (0, 1, (1, 2)) + strides = (1, 2, (1, 2)) + + cases = product(shapes, kernel_sizes, dilations, paddings, strides) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for shape, kernel_size, dilation, padding, stride in cases: + tensor = make_arg(shape) + yield SampleInput(tensor, kernel_size, dilation, padding, stride) + + # With default args + yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3)) + + +def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((S, 1, S, 1), ()), + ((1, 1, 1, 1), ()), + ((1, 1, 1, 1), (0,)), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (2,)), + ((S, 1, S, 1), (-2,)), + ((), (0, )), + ) + + for shape, args in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, args=args) + + +def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((1, 1, 1, 1), ()), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (1, 3)), + ((S, 1, S, 1), (1, 2,)), + ((), (0,)), + ) + + for shape, dims in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, dims) + + +def _squeeze_ref(x, axis=None): + # NumPy doesn't allow squeezing scalars + if x.ndim == 0: + return x + + if isinstance(axis, Sequence): + # Numpy doesn't allow specifying non-singular dimensions + axis = tuple(a for a in axis if x.shape[a] == 1) + + if isinstance(axis, int) and x.shape[axis] != 1: + return x + + return np.squeeze(x, axis) + +def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs): + assert mode in ('constant', 'reflect', 'replicate', 'circular') + if mode in ['reflect', 'replicate']: + cases: tuple = ( # ignore + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + elif mode == 'constant': + cases = ( + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((1, 3), (0, 2, 0, 1)), + ((5, 3), (-1, -2, 1, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((0, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3), (1, 1, 1, 1, 1, 1)), + ((0, 3, 3, 3), (1, 2)), + ((0, 3, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((3, 3, 5, 5), (1, 2)), + ((3, 3, 5, 5), (0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 2)), + ((1, 3, 3, 3, 3), (0, 1)), + ((1, 3, 3, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + else: # mode == 'circular' + if dtype == torch.bool: + # test_dtypes fails on ASAN with for the case ab + # runtime error: load of value 190, which is not a valid value for type 'bool' + # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562 + # Reference Issue: https://github.com/pytorch/pytorch/issues/63034 + cases = ( + ((2, 3, 3), (1, 2)), + ((1, 3, 3), (1, 2)), + ) + else: + cases = ( + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if mode == 'constant': + # Default args + yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),)) + + if mode in ['reflect', 'replicate', 'circular']: + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode)) + else: # mode == 'constant' + for pad_value in (1., 2.): + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode, pad_value)) + +def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple = ( + ((5, 3, 4, 4), (-4, 5, 0, 0)), + ((6, 2, 4, 4), (0, 0, 2, -4)), + ((5, 6, 4, 4), (5, -4, -4, 3)), + ((4, 2, 5, 5), (-2, -1, 4, 6)), + ((2, 6, 5, 5), (8, -1, -1, -3)), + ((8, 1, 5, 5), (-2, -1, -1, -3)), + ) + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, 'replicate')) + +def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs): + # Inherit sample inputs from nn.pad, but transform them to fit + # constant_pad_nd's interface + nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args, + mode='constant', **kwargs) + + # NOTE: primTorch is more strict about the type of the fill value argument + # So we must cast it to the correct dtype + from torch._prims_common import dtype_to_type + scalar_type = dtype_to_type(dtype) + + def drop_mode_argument(input, pad, mode=None, value=None): + if value is None: + return SampleInput(input, args=(pad,)) + else: + return SampleInput(input, args=(pad, scalar_type(value))) + + for sample in nn_samples: + yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs) + +def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(()), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1) + yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1) + yield SampleInput(make_input((4, 1)), repeats=torch.arange(4, device=device), dim=0, output_size=6) + + +def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): + def mt(shape, **kwargs): + return make_tensor(shape, device=device, dtype=dtype, + requires_grad=requires_grad, **kwargs) + + yield SampleInput(mt(100), n_fft=10, return_complex=True) + yield SampleInput(mt(100), n_fft=10, return_complex=False) + if dtype.is_complex: + yield SampleInput(mt(100), n_fft=10) + + for center in [False, True]: + yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True) + yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4, + center=center, return_complex=True) + + window = mt(16, low=.5, high=2.0) + yield SampleInput( + mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + yield SampleInput( + mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + if not dtype.is_complex: + yield SampleInput( + mt((10, 100)), n_fft=16, window=window, onesided=False, + return_complex=True) + + +def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def mt(shape, **kwargs): + real_shape = shape if dtype.is_complex else shape + (2,) + return make_arg(real_shape, **kwargs) + + yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10)) + yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False)) + yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True)) + + for center in [False, True]: + yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center)) + yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center)) + + window = make_arg(10, low=.5, high=2.0) + yield SampleInput(mt((10, 10, 6)), kwargs=dict( + n_fft=10, window=window, center=center, return_complex=dtype.is_complex)) + yield SampleInput(mt((10, 10, 10)), kwargs=dict( + n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True)) + + real_window = window if not dtype.is_complex else window.real + yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center)) + +def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs): + # create a helper function wrapping `make_tensor` + make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1) + + batches = [(), (0, ), (2, ), (2, 1)] + ns = [5, 2, 0] + tf = [True, False] + for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf): + input = make_input((*batch, m, n)) + reflectors, tau = torch.geqrf(input) + reflectors.requires_grad_(requires_grad) + tau.requires_grad_(requires_grad) + other_matrix_shape = (m, n) if left else (n, m) + other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad) + yield SampleInput(reflectors, tau, other, left=left, transpose=transpose) + + +def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs): + cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse( + op_info, device, dtype, requires_grad=False + ) + + for sample in cholesky_inverse_samples: + psd_matrix = sample.input + sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + sample.args = (psd_matrix.requires_grad_(requires_grad),) + yield sample + + +def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_fullrank_matrices_with_distinct_singular_values, + dtype=dtype, device=device, requires_grad=requires_grad) + + # not needed once OpInfo tests support Iterables + batch_shapes = ((), (3,), (3, 3)) + for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)): + shape = batch_shape + (S + size_delta, S) + input = make_arg(*shape) + yield SampleInput(input, args=(True, get_infos)) + + +def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs): + def out_fn(output): + return output[1], output[2] + + for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs): + lu_data, pivots = torch.linalg.lu_factor(lu_sample.input) + lu_data.requires_grad_(requires_grad) + yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn) + + +def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2))) + + for arg in args: + yield SampleInput(make_arg((0, 0, 0)), args=arg) + yield SampleInput(make_arg((S, S, S)), args=arg) + + # Scalar tensor + yield SampleInput(make_arg(()), args=(10, )) + +def error_inputs_roll(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "`shifts` required" + s1 = SampleInput(make_arg((S,)), ()) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = ("shifts and dimensions must align") + s2 = SampleInput(make_arg((S, S)), (2, 1), 0) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = ("out of range") + s3 = SampleInput(make_arg((S, )), 0, 2) + yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError) + + err_msg4 = ("Dimension specified as 0") + s4 = SampleInput(make_arg(()), 0, 0) + yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError) + +def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)]) + + yield SampleInput(make_arg((S, S, S))) + for arg in args: + yield SampleInput(make_arg((S, S, S)), args=arg) + + +def error_inputs_rot90(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "expected total rotation dims" + s1 = SampleInput(make_arg((S, S)), dims=(0,)) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = "expected total dims >= 2" + s2 = SampleInput(make_arg((S,))) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = "expected rotation dims to be different" + s3 = SampleInput(make_arg((S, S)), dims=(1, 1)) + yield ErrorInput(s3, error_regex=err_msg3) + + +def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs): + tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype, + requires_grad=requires_grad) + tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype, + requires_grad=requires_grad) + + yield SampleInput(tensor_nd()) + yield SampleInput(tensor_nd(), dim=1) + yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False) + + yield SampleInput(tensor_nd(), dim=(1,), correction=1.3) + yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) + yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) + yield SampleInput(tensor_nd(), dim=None, correction=None) + yield SampleInput(tensor_nd(), dim=None, correction=-1) + yield SampleInput(tensor_nd(), dim=None, correction=-5) + yield SampleInput(tensor_nd(), correction=0.5, keepdim=True) + yield SampleInput(tensor_nd(), correction=0, keepdim=True) + yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) + + +def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad) + + # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + yield SampleInput(make_arg((S, S)), True) + yield SampleInput(make_arg((S,)), False) + + +def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs): + shapes = [(2,), (1, 2), (3, 2), (2, 3)] + for shape in shapes: + yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + + +def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs): + return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad)) + +def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op_info, device, dtype, requires_grad, **kwargs) + if dtype.is_floating_point: + yield SampleInput(make_tensor(5, dtype=dtype, device=device, requires_grad=requires_grad), -3.14) + + +def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_correlation_inputs(device, dtype, requires_grad): + yield SampleInput(t) + num_observations = t.numel() if t.ndimension() < 2 else t.size(1) + fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10) + aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad) + for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]): + yield SampleInput(t.clone().requires_grad_(requires_grad), + correction=correction, fweights=fw, aweights=aw) + + +def error_inputs_cov(op_info, device, **kwargs): + a = torch.rand(S, device=device) + yield ErrorInput( + SampleInput(torch.rand(S, S, S, device=device)), + error_regex="expected input to have two or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, S, device=device)), + error_regex="expected fweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(S, S, device=device)), + error_regex="expected aweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, device=device)), + error_regex="expected fweights to have integral dtype") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([1, 1], device=device)), + error_regex="expected aweights to have floating point dtype") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([1], device=device)), + error_regex="expected fweights to have the same numel") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(1, device=device)), + error_regex="expected aweights to have the same numel") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4 , -5], device=device)), + error_regex="fweights cannot be negative") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)), + error_regex="aweights cannot be negative") + + +def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = [((1, 2, 3, 4), (0, 2, 3, 1)), + ((1, 2, 3, 4), (0, -2, -1, 1)), + ((), ()), + ((1, 2, 3, 4), (2, 1, 3, 0))] + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=(args,)) + +def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((), ()), + ((1,), (0,)), + ((2, 2), (1, 0)), + ((2, 2), (0, 1)), + ((2, 0, 1), (0, 2, 1)), + ((3, 4, 2), (2, 1, 0)), + ((3, 4, 2), (1, 0, 2)), + ((3, 4, 2), (0, 1, 2)), + ) + + # Adds tricky permutations and permutations with noncontiguity + for shape, permutation in cases: + for p in itertools.permutations(permutation): + a = make_arg(shape).permute(p) + yield SampleInput(a, args=(permutation,)) + + a = make_arg(shape, noncontiguous=True).permute(p) + yield SampleInput(a, args=(permutation,)) + +def error_inputs_softshrink(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}), + error_regex=r"lambda must be in range \[0,.*input dtype.*found -0\.5") + +def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for lbda in (0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + # Note that unlike softshrink, lambd is allowed to be negative for hardshrink + for lbda in (-0.5, 0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + + +def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of min_val and max_val beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for max_val, min_val in ((0.5, -0.5), (0., 0.)): + yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def error_inputs_hardtanh(op_info, device, **kwargs): + # Tests that hardtanh errors out when passed min_val > max_val. + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}), + error_type=ValueError, error_regex="min_val cannot be greater than max_val") + +def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs): + def c(t): + return t.clone().requires_grad_(requires_grad) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + x = make_arg((3,)) + y = make_arg((4,)) + A = make_arg((2, 3,)) + B = make_arg((1, 3,)) + C = make_arg((1, 2, 3,)) + D = make_arg((1, 3, 4,)) + E = make_arg((4, 4,)) + H = make_arg((3, 3,)) + I = make_arg((1, 3, 1,)) + + # Vector operations + yield SampleInput([c(x)], 'i->') # sum + yield SampleInput([c(x), c(y)], 'i,j->ij') # outer + + # Matrix operations + yield SampleInput([c(A)], "ij->i") # col sum + yield SampleInput([c(A), c(B)], "ij,kj->ik") # matmul + yield SampleInput([c(A), c(E)], "ij,Ab->ijAb") # matrix outer product + + # Tensor operations + yield SampleInput([c(C), c(D)], "aij,ajk->aik") # batch matmul + yield SampleInput([c(D), c(E)], "aij,jk->aik") # tensor matrix contraction + yield SampleInput([c(C), c(B)], "ijk,ik->j") # non contiguous + + # Test diagonals + yield SampleInput([c(I)], 'iji->j') # non-contiguous trace + + # Test ellipsis + yield SampleInput([c(H)], "i...->...") + yield SampleInput([c(C), c(x)], '...ik, ...j -> ij') + + +def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((S, M, S), (S, 0, M)) + all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ()) + + for size, dims in product(sizes, all_dims): + yield SampleInput(make_arg(size), kwargs={"dims": dims}) + +def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs): + shapes = [ + (S, M, S), + (S, 0, M), + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes) + +def error_inputs_fliplr(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)), + error_regex="Input must be >= 2-d.") + +def error_inputs_flipud(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)), + error_regex="Input must be >= 1-d.") + +def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False) + shape = (S, M, S) + + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:]))) + yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),)) + yield SampleInput(make_arg(shape), args=(None, make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape), None)) + # test type promotion + yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None)) + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape))) + +def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs): + yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad) + supported_dtypes = op.supported_dtypes(device) + + # broadcasting and oncontiguous cases + cases = ( + ((4, 4), (4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 1, 4)), + ((4, 4, 1), (1, 4, 4), (4, 4)), + ((4, 1), (1, 4, 4), (1, 4)), + ((4, 4), (), (4, 4)), + ((4, 4), (), ()), + ((), (4, 4), (1, 4, 4)), + ) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c))) + yield SampleInput(make_arg(a, noncontiguous=True), + args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1))) + + # scalar cases + if supports_scalars: + cases = [ + ((), 1, 2,), + ((), 1., 2), + ((4, 4), 1., 2,), + ((3, 4), make_scalar_tensor(), make_scalar_tensor()), + ] + + if torch.complex64 in supported_dtypes: + cases.extend([ + ((3, 1, 4), complex(1, 2), 3.), + ]) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(b, c)) + + # type promotion cases + # int x float + if torch.float in supported_dtypes and torch.long in supported_dtypes: + a = make_arg((), dtype=torch.long) + b = make_arg((1, 4), dtype=torch.float) + c = make_arg((3, 4)) + + cases = ( + (a, b, c), + (c, a, b), + ) + + for a, b, c in cases: + yield SampleInput(a, args=(b, c)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan')) + + a = make_arg((12,)) + a[4] = nan + a[7] = nan + b = make_arg((12,)) + b[1] = nan + b[7] = nan + c = make_arg((12,)) + c[9] = nan + + yield SampleInput(a, args=(b, c)) + + +def _clamp_min_numpy(a, min=None): + return np.maximum(a, min) + + +def _clamp_max_numpy(a, max=None): + return np.minimum(a, max) + + +def _clamp_numpy(a, min=None, max=None): + if min is None: + return np.minimum(a, max) + if max is None: + return np.maximum(a, min) + + return np.minimum(max, np.maximum(a, min)) + + +def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_zeros(dim_select): + assert len(dim_select) == 2 + result = make_arg(3 * (S,)) + result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() + result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() + result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() + return result + + for dim in range(3): + yield SampleInput(make_arg((S, S, S)), args=(dim,)) + # Scalar tensors and empty tensor + for size in [(), (1,), (0,)]: + yield SampleInput(make_arg(size), args=(0,)) + + yield SampleInput(prod_zeros([0, 1]), args=(1,)) + yield SampleInput(prod_zeros([0, 2]), args=(1,)) + yield SampleInput(prod_zeros([1, 2]), args=(1,)) + + # test dtype kwarg + yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype}) + +def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S), ()) + return (SampleInput(make_arg(size)) for size in sizes) + +def error_inputs_complex(op_info, device, is_ref=False, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + other_dtype = torch.float16 if device.startswith("mps") else torch.float64 + other_dtype_name = "Half" if device.startswith("mps") else "Double" + + if is_ref: + error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" + error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" + error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" + else: + error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" + error_dtype = f"Expected object of scalar type Float but got scalar type {other_dtype_name} for second argument" + error_out = f"Expected object of scalar type Complex{other_dtype_name} but got scalar type ComplexFloat for argument 'out'" + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), + error_type=RuntimeError, error_regex=error_float) + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=other_dtype)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=other_dtype), make_arg(M, S, dtype=other_dtype), + out=make_arg(M, S, dtype=torch.complex64)), + error_type=RuntimeError, error_regex=error_out) + +def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = (S, S) + yield SampleInput(make_arg(shape), make_arg(shape)) + +def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_single_zero(): + result = make_arg(2 * (S,)) + result[0, 1] = 0 + return result + + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + # only Tensor, ignore other inputs + yield SampleInput(sample.input.clone().requires_grad_(requires_grad)) + yield sample + + # Generates samples with keepdim = True + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + sample.kwargs['keepdim'] = True + yield sample + + yield SampleInput(prod_single_zero()) + yield SampleInput(make_arg((3, 3, 3)), args=(1,)) + yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True}) + + yield SampleInput(make_arg((3, 0)), args=(1,)) + yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True}) + yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + + # test zero scalar tensor + zero = make_arg(()) + zero.zero_() + yield SampleInput(zero.clone().requires_grad_(requires_grad)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), + args=(0,), + kwargs={'keepdim': True}) + +def error_inputs_neg(op_info, device, **kwargs): + si = SampleInput(torch.tensor((False, True), device=device)) + msg = ("Negation, the `\\-` operator, on a bool tensor is not supported." + " If you are trying to invert a mask, use the `\\~` or" + " `logical_not\\(\\)` operator instead.") + yield ErrorInput(si, error_regex=msg) + +def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg(M)) + + tensors = ( + make_arg((M, M)), + make_arg((3, 5)), + make_arg((5, 3)), + ) + + args = ((), (2,), (-2,), (1,), (2,)) + + for tensor, arg in product(tensors, args): + yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg) + +def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_diagonal_diag_embed( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes1d = ((0,), (1,)) + shapes2d = ((L, M),) + shapes3d = ((L, M, S),) + + kwargs1d = {} + + kwargs2d = ( + # dim1 > dim2 is allowed + dict(dim1=1, dim2=0), + # negative dims are allowed + dict(dim1=-2, dim2=-1), + # one dim negative and the other nonnegative is allowed + dict(dim1=-1, dim2=0), + # out of bounds offset should return an empty tensor in diagonal and + # offset the diagonal in diag_embed + dict(offset=100), + ) + + kwargs3d = kwargs2d + ( + # make sure we can use non-sequential dims + dict(offset=-1, dim1=0, dim2=2), + ) + + samples1d = product(shapes1d, kwargs1d) + samples2d = product(shapes2d, kwargs2d) + samples3d = product(shapes3d, kwargs3d) + + for shape, kwargs in chain(samples1d, samples2d, samples3d): + if 'diagonal' in op_info.name: + # these are error inputs for diagonal + if shape in ((0,), (1,)): + continue + yield SampleInput(input=make_arg(shape), kwargs=kwargs) + + +def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # Shapes for 2D Tensors + shapes_2d = ((M, M), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((M, M, M),) + + args_2d = ((), (2,), (-2,), (1,)) + args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1)) + + for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)): + input_ = make_arg(input_shape) + # We can programmatically figure out the right shape for src: + # It should be the same size as input.diagonal(other_args...) + if not isinstance(arg, tuple): + arg_tuple = (arg,) + else: + arg_tuple = arg + src_shape = input_.diagonal(*arg_tuple).size() + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *arg_tuple)) + + +def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + +def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs): + batch_size, num_classes = shape = (2, 3) + reductions = ("mean", "sum", "none") + + input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [ + (shape, {}), + ((*shape, 1), {}), + ((*shape, 1, 2), {}), + ((*shape, 1, 2, 3), {}), + *[(shape, dict(reduction=reduction)) for reduction in reductions], + *[ + ( + shape, + dict( + weight=make_tensor((num_classes,), device=device, dtype=dtype), + reduction=reduction, + ), + ) + for reduction in reductions + ], + (shape, dict(ignore_index=1)), + ] + + for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)): + input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad) + + if probabilities_target: + # ignore_index is not supported for probabilities target + if "ignore_index" in kwargs: + continue + + target = make_tensor( + input_shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + else: + target = make_tensor( + (batch_size, *input_shape[2:]), + low=0, + high=num_classes, + device=device, + dtype=torch.long, + ) + + if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]): + # make sure at least one item in target is not ignored + target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0] + + yield SampleInput(input, target, **kwargs) + + +def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): + low, high = op_info.domain + + # Note: Operator is very sensitive at points near the + # start and end of domain and leads to NaN for float16 + # if domain_eps is 1e-5. + if dtype.is_floating_point or dtype.is_complex: + domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2 + + low = low + domain_eps + high = high - domain_eps + + make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg((S, S, S)), 0.2) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg(()), 0.2) + +def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # isin has two paths based on the size of elements and test_elements. + # if elements.numel() < 10 * pow(test_elements.numel(), 0.145): + yield SampleInput(make_arg((L,)), args=(make_arg((S,)),)) + # else: + yield SampleInput(make_arg((S,)), args=(make_arg((L,)),)) + +def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S)))) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))), + broadcasts_input=True) + +def error_inputs_masked_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float) + for mask_dtype in [torch.float, torch.uint8]: + yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype), + make_arg(3, 4))), + error_regex=r"masked_scatter_ only supports boolean masks") + +def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10)) + + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg(())), + broadcasts_input=True) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, 10), + broadcasts_input=True) + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CUDA but `value` is a CPU scalar tensor. + yield SampleInput(make_arg((S, S)), + args=(torch.randn(S, S, device=device) > 0, + make_tensor((), device="cpu", dtype=dtype))) + +def error_inputs_masked_fill(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # `value` is not a 0-D tensor. + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))), + error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension") + # downcasting complex value (scalar overload) + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)), + error_regex=r"value cannot be converted to type .* without overflow") + # downcasting complex value (tensor overload) + yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device), + args=(make_arg(()) > 0, torch.tensor(1j, device=device))), + error_regex=r"value cannot be converted to type .* without overflow") + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CPU but `value` is a CUDA scalar tensor. + yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'), + args=(torch.randn(S, S, device='cpu') > 0, + torch.randn((), device='cuda'))), + error_regex=r"to be on same device") + + +def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + + yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0) + + yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0) + yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0) + +def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S))) + yield SampleInput(make_arg((S, S, S))) + +def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, + high=None, requires_grad=requires_grad) + test_cases = (((L,), (L,)), + ((S, M), (M,)), + ((M,), (M, S)), + ((S, M), (M, S)), + ((S, 0), (0, M)), + ((S, S, M), (M,)), + ((S, S, M), (M, S)), + ((S, S, 0), (0, S)), + ((M,), (S, M, S)), + ((S, M), (S, M, S)), + ((0, 0), (S, 0, 0)), + ((S, S, M, M), (S, S, M, S)), + ((S, S, M, M), (M,)), + ((M,), (S, S, M, S)), + ((S, S, S), (1, S, S)) + ) + for lhs_shape, rhs_shape in test_cases: + lhs = make_arg(lhs_shape) + rhs = make_arg(rhs_shape) + if not is_rmatmul: + yield SampleInput(lhs, rhs) + else: + yield SampleInput(rhs, lhs) + + +def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype, + requires_grad: bool, + *, variant: str, **kwargs) -> list[SampleInput]: + if variant == 'variadic': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return tensors + elif variant == 'list': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return [tensors] + else: + raise ValueError( + 'Unsupported variant, must be one of {"variadic", "list"}. ' + f'Got "{variant}".') + + SCALAR = torch.Size([]) + VECTOR = torch.Size([3]) + test_cases: list[list[torch.Size]] = [ + [SCALAR], + [VECTOR], + [VECTOR, SCALAR], + [VECTOR, SCALAR, VECTOR], + [VECTOR, SCALAR, VECTOR, SCALAR], + ] + + for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}): + args = make_inputs( + [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes]) + yield SampleInput(*args, indexing=indexing) + + +def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor_shapes = ((S, S), ()) + ns = (1, 2, 3, 4, 5) + + # Since the accepted lower bound for input + # to mvlgamma depends on `p` argument, + # the following function computes the lower bound + # which we pass to `make_tensor`. + def compute_min_val(p): + return (p - 1.) / 2 + + for shape, n in product(tensor_shapes, ns): + min_val = compute_min_val(n) + if not dtype.is_floating_point: + # Round-up minimum value for integral dtypes + min_val += 1 + else: + min_val += 2 * torch.finfo(dtype).eps + yield SampleInput(make_arg(shape, low=min_val), args=(n,)) + + +# Since `mvlgamma` has multiple entries, +# there are multiple common skips for the additional +# entries. Following function is a helper to that end. +def skips_mvlgamma(skip_redundant=False): + skips = ( + # outside domain values are hard error for mvlgamma op. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'), + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ) + if skip_redundant: + # Redundant tests + skips = skips + ( # type: ignore[assignment] + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + ) + return skips + + +# To test reference numerics against multiple values of argument `p`, +# we make multiple OpInfo entries with each entry corresponding to different value of p. +# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing. +def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs): + return UnaryUfuncInfo('mvlgamma', + ref=reference_mvlgamma if TEST_SCIPY else None, + aliases=('special.multigammaln',), + variant_test_name=variant_test_name, + domain=domain, + decorators=(precisionOverride({torch.float16: 5e-2}),), + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_mvlgamma, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=skips, + sample_kwargs=sample_kwargs) + + +def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs): + def _make_tensor_helper(shape, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(_make_tensor_helper((S, S, S)), 0) + yield SampleInput(_make_tensor_helper((S, S, S)), 1) + yield SampleInput(_make_tensor_helper(()), 0) + + if supports_dtype_kwargs: + # NOTE: if `dtype` is not same as input, then inplace variants fail with + # `provided dtype must match the dtype of self tensor in cumsum` + yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype) + + +def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((), (0, 1, 1)), + ((S, S, S, S), (0, 3, 1)), + ((S, S, S, S), (1, 3, 1)), + ((S, S, S, S), (2, 3, 1)), + ((S, S, S, S), (3, 3, 1)), + ((S, S, S, S), (0, 3, 2)), + ((S, S, S, S), (1, 3, 2)), + ((S, S, S, S), (2, 3, 2)), + ((S, S, S, S), (3, 3, 2)), + ((S, S, S, S), (0, 4, 1)), + ((S, S, S, S), (1, 4, 1)), + ((S, S, S, S), (2, 4, 1)), + ((S, S, S, S), (3, 4, 1)), + ((M,), (0, 3, 1)), + ((M,), (0, 3, 2)), + ((M,), (0, 3, 3)), + ((1000,), (0, 3, 11)), + ((1000,), (0, 2, 27)), + ((10, 10), (0, 1, 2)), + ((10, 10), (1, 2, 3)), + ((10, 10), (1, 2, 2)), + ((S, S, S), (2, 3, 2)), + ) + + for shape, arguments in test_cases: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *arguments) + +def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if list_args: + cases = ( + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) + ) + else: + cases = ( # type: ignore[assignment] + ((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs): + def apply_grad(t): + if dtype in floating_types_and(torch.float16, torch.bfloat16): + t.requires_grad_(requires_grad) + + def large_1d_unique(dtype, device): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype) + apply_grad(res) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique(dtype, device)) + + yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + +def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4) + # broadcast rhs with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S))) + # broadcast rhs and weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,))) + # broadcast lhs + yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # scalar broadcast_lhs + yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # tensor broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata( + broadcasts_input=True) + # no broadcast with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S))) + # broadcast lhs with weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor variant + yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata( + broadcasts_input=True) + + if dtype.is_complex: + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4j) + yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j) + +def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs): + cases = ( + ((2, 2, 2), (2, 2, 2), (2)), + ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])), + ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])), + ) + for first_shape, second_shape, dims in cases: + yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + make_tensor(second_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + dims=dims) + +def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + test_cases = ( + ((S, S), (M, L)), + ) + + for input_shape, other_shape in test_cases: + input = make_arg(input_shape) + other = make_arg(other_shape) + yield SampleInput(input, other) + +def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(S)) + yield SampleInput(make_arg(), make_arg(S, S)) + +def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))), + (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))), + (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor(()), (0, zero.detach().clone(), _tensor(()))), + (_tensor(()), (0, zero.detach().clone(), 2.5)), + ) + + for tensor, args in test_cases: + yield SampleInput(tensor, *args) + + if not requires_grad: + yield SampleInput(tensor.detach().clone(), *args, reduce='add') + + if dtype.is_floating_point: + yield SampleInput(tensor.detach().clone(), *args, reduce='multiply') + +def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor(()), 0, zero.detach().clone(), _tensor(())) + +def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + gather = partial(gather_variable, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + ((M, S), 0, gather((S, S), 1, M), (S, S)), + ((M, S), 1, gather((S, S), 0, S), (S, S)), + ((M, S), -1, gather((S, S), 0, S), (S, S)), + ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)), + ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)), + ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)), + ((), 0, zero.detach().clone(), ()), + ) + + reduce = op_info.variant_test_name + for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]): + yield SampleInput(make_arg(inp_shape), + args=(dim, index, make_arg(src_shape), reduce), + kwargs={'include_self': include_self}) + + + # Sample inputs to test edge cases for backward + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + if requires_grad and reduce == 'prod': + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]) + # (c) no zeros reduced (self([2, 1])) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(1, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = ( + # inp_shape, dim, lengths, unsafe + ((S,), 0, [0, 1, 2, 2], False), + ((S,), 0, [0, 1, 2, 2], True), + ((S,), 0, [2, 0, 3, 0], False), + ((S, S), 0, [0, 1, 2, 2], False), + # test when lengths do not sum to dim size + ((M, S, S), 0, [1, 2, 0, 6, 0], True), + # test for higher dimensions + ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ) + + reductions = ["max", "mean", "min", "sum", "prod"] + for args, reduce, initial in product(test_cases, reductions, [1, 2]): + inp_shape, dim, lengths, unsafe = args + lengths_t = torch.tensor(lengths, dtype=torch.long, device=device) + sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial} + if mode == 'lengths': + sample_input_kwargs['lengths'] = lengths_t + elif mode == 'offsets': + zeros_shape = list(lengths_t.shape) + zeros_shape[dim] = 1 + offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim) + sample_input_kwargs['offsets'] = offsets_t + else: + raise RuntimeError(f"mode most be one of 'offsets' or 'lengths' got '{mode}'.") + yield SampleInput(_tensor(inp_shape), + args=(reduce,), + kwargs=sample_input_kwargs) + + +def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg((S, S, S), noncontiguous=True)) + +def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput( + torch.tensor( + [[3, 8, 13], [0, 5, 10]], + device=device, + dtype=dtype), + (4, 5)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (4, 2**30)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (2**30, 4)) + yield SampleInput( + torch.tensor(2, device=device, dtype=dtype), + (2, 2)) + max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1 + yield SampleInput( + torch.tensor(max_val - 1, device=device, dtype=dtype), + (1, max_val)) + yield SampleInput( + torch.tensor([22, 41, 37], device=device, dtype=dtype), + (7, 6)) + yield SampleInput( + torch.tensor(min(1621, max_val), device=device, dtype=dtype), + (6, 7, 8, 9)) + yield SampleInput( + torch.tensor([], device=device, dtype=dtype), + (10, 3, 5)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], + device=device, + dtype=dtype), + (5, 8)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]], + device=device, + dtype=dtype), + (5, 8, 10)) + yield SampleInput( + torch.tensor(0, device=device, dtype=dtype), + ()) + + a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]]) + b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]]) + _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]]) + b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]]) + _, i1, i2 = np.intersect1d(a, b, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + +def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((M, M), ()), + ((M, M), (2,),), + ((M, S), ()), + ((M, S), (-1,)), + ((M, M), (2,),), + ((S, M, S), ()), + ((S, M, S), (2,)), + ((3, 3, S, S), ()),) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def error_inputs_tril_triu(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for input.ndim <= 2 + yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions") + +def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs): + # (row, col, offset) + args_list = ((0, 0), + (20, 0), + (0, 20), + (20, 21, 0), + (20, 21, 7), + (20, 21, -7), + # Large test cases below are deliberately commented out to speed up CI + # tests and to avoid OOM error. When modifying implementations of + # tril_indices and triu_indices, please enable these tests and make sure + # they pass. + # (2, 68435455, 3), + # (5000, 5000), + # (5000, 5000, 1234), + # (5000, 5000, -1233), + ) + for args in args_list: + yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device}) + +def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, M, S))) + yield SampleInput(make_arg(())) + +def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs): + # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format + # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous + yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs) + + shapes = ( + (3, 5, 6), + (1, 1, 3, 5, 6), + (1, 1, 3, 5, 6, 1, 1), + (1, 0, 3, 5, 0, 2), + (1, 0, 3, 5, 0, 0, 1, 1, 2), + (), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + + yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + for shape, strides, offset in strided_cases: + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset)) + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format}) + + # channels last 2D + yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last}) + a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last}) + + # channels last 3D + yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d}) + a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d}) + + +def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # list of tuples (shape, shape) defining the shapes of the input and output tensors + sample_shapes = [ + ((), ()), + ((S,), (1,)), + ((S, S), (1, 1)), + ((S, S), (1, S)), + ((S, S), (S, S)), + ((S, S, S), (S, 1, S)), + ] + + for input_shape, output_shape in sample_shapes: + yield SampleInput(make_arg(input_shape), args=(output_shape,)) + if output_shape == (): + continue + yield SampleInput(make_arg(input_shape), args=(list(output_shape),)) + yield SampleInput(make_arg(input_shape), args=(*output_shape,)) + + +def error_inputs_sum_to_size(op_info, device, **kwargs): + shape = (M, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M)) + yield ErrorInput(si, error_regex=err_msg) + + shape = (M + 1, S, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1)) + yield ErrorInput(si, error_regex=err_msg) + + +def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + cases = (((S, S, S), (S * S, S)), + ((), ()), + ((), (1, 1, 1)), + ) + + for shape, args_or_shape in cases: + # Update `args` based on operator + if op_info.name == 'resize_': + # resize_ takes shape/tuple of ints, + args = (args_or_shape, ) + elif op_info.name == 'resize_as_': + # resize_as_ takes another tensor + args = (make_arg(shape, requires_grad=False), ) # type:ignore[assignment] + else: + raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") + + yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) + +def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = ( + # a, b, is_tensor_supported + ((S, S, S), (S * S, S), True), + ((S * S, S), (S, S, S), True), + ((S * S, S), (S, -1, S), False), # neg index + ((S * S * 2, S), (S, -1), False), # neg index + ((S,), (S,), True), + ((), (), False), # empty + ((), (1,), True), + ) + + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs) + + cases = ( + # a, b, is_tensor_supported + ((125,), (25, 5), True), + ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True), + ((16, 32), (2, 4, 1, 4, 4, 1, 4), True), + ((16, 12), (12, 16), True), + ((1, 16, 12), (12, 16), True), + ((1, 5, 1, 5), (25, 1), True), + ((2, 4, 2), (4, 4), True), + ((1, 4), (1, 1, 2, 1, 2), True), + ((3, 5, 7), (7, 5, 3), True), + ((1,), (), False), # empty + ((5, 0, 2, 3), (5, 0, 2, 3), True), + ((2, 1, 0, 3, 1), (5, 0), True), + ((1,), (), False), # empty + ((4, 5, 6), (4, 5, 6, 1, 1, 1), True), + ((), (1, 1, 1, 1), False), # empty + ) + + irreversible_cases = ( + ((), (-1,), False), # neg index, empty + ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if kwargs.get("tensor_arg"): + # convert to tensor + yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),)) + yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),)) + else: + yield SampleInput(make_arg(a), args=(b,)) + yield SampleInput(make_arg(b), args=(a,)) + + for a, b, is_tensor_supported in irreversible_cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def error_inputs_view_reshape(op, device, **kwargs): + + cases = ( + # a, b, is_tensor_supported + # Reshape to different numel + ((2,), (), False), # empty + ((1, 3, 0), (), False), # empty + ((4, 3), (4, 2), True), + ((1, 3, 5), (5, 2, 2), True), + # No valid inference + ((1, 3, 5), (5, -1, 2), False), # neg index + # Two inferred shapes + ((1, 3, 5), (5, -1, -1), False), # neg index + ((1), (0, -1), False), # neg index + ((0, 5), (0, -1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if b == (5, -1, -1): + error_regex = "only one dimension can be inferred" + elif a == (0, 5): + error_regex = (r"cannot reshape tensor of 0 elements into shape " + r"\[0, -1\] because the unspecified dimension size " + r"-1 can be any value and is ambiguous") + else: + # to avoid having issues with a regex + shape = ', '.join(map(str, b)) + size = a if type(a) is int else functools.reduce(operator.mul, a, 1) + error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}" + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception, + error_regex=error_regex) + + +def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + yield SampleInput([make_tensor_partial(shape) for shape in shapes]) + +def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple[tuple, tuple] = ( # type: ignore[assignment] + ((S, 2, 1), (S, 3, 1)), + ((S), (S, 5)), ((), (1, S)) + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape1, shape2 in cases: + yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)]) + +def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + if len(shape) > 1: + yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1) + +def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs) + + # shape x start_dim x end_dim + cases = ( + ((5, 4, 0, 1, 3, 7), 1, 3), + ((5, 4, 0, 1, 3, 7), 4, 5), + ((5, 4, 1, 1, 3, 7), 2, 3), + ((), 0, -1), + ((1,), 0, -1), + ((3, 7, 5), 1, 2), + ((4, 5), 1, 1), + ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2), + ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1), + ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1), + ((2, 4, 2), 0, 1), + ((4, 2, 2), 1, 2), + ((0, 3, 4, 5), 1, 3), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape, start, end in cases: + yield SampleInput(make_arg(shape), args=(start, end,)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,)) + yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,)) + +def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs): + # in_shape, dim, sizes + args = (((8,), 0, (8,)), + ((8,), 0, (4, 2)), + ((8,), -1, (2, 2, 2)), + ((8,), -1, (-1, 2)), + ((3, 6, 2), 1, (2, 3)), + ((3, 6, 2), -2, (2, 3)), + ((3, 6, 2), -2, (-1, 3)), + ((3, 2, 12), 2, (3, 2, 2)), + ((4, 0), 0, (2, 2)), + ((4, 0), 1, (2, 0, 0, 0)), + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for in_shape, dim, sizes in args: + yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes)) + + +def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (1, 2)), + ((S, S, S), (-1, 2)), + ((S, S, S), (-1, -1)), + ((S, S, S), (1, -1)), + ((S, S), (-1, 2)), + ((S,), (0, 2)) + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (S, S), (1, 2)), + ((S, S, S), (S, S), (-1, 2)), + ((S, S, S), (S, S), (-1, -1)), + ((S, S, S), (S, S), (1, -1)), + ((S,), (), (0, 2)) + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + + +def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)), + ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)), + ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (1, 0, L, 1)), + ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)), + ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (2, 0, L, 1)), + ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)), + ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)), + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + +def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1, S), (-1, S, -1)), + ((S, 1, S), (-1, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=(args,)) + +def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + + cases = (((S, 1, 1), (S, S, S)), + ((), ()), + ((), (1, 1)), + ) + + for shape, shape_other in cases: + yield SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False),)) + + +def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + def make_bool_mask(shape): + # Make sure at least one element is nonzero, + # except for empty tensor + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + + if mask_t.numel() == 0: + return mask_t + elif mask_t.numel() == 1: + mask_t.fill_(True) + return mask_t + + if mask_t.sum() == 0: + def random_index(shape): + return tuple(random.randrange(0, max_idx) for max_idx in shape) + + mask_t[random_index(mask_t.shape)] = True + return mask_t + + return mask_t + + cases = (((M, M), (M, M), (M, M), False), + ((M, 1, M), (M, M), (M, M, 1), True), + ((), (), (), False), + ((M, 1, M), (), (M, M, 1), True), + ((), (M, M), (), True), + ((), (2), (1, 1), True), + ) + + for shape, mask_shape, other_shape, broadcasts_input in cases: + yield SampleInput(make_arg(shape), + args=(make_bool_mask(mask_shape), make_arg(other_shape)), + broadcasts_input=broadcasts_input) + +# TODO: add reference inputs for where(condition) signature +def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs) + + make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # noncontiguous + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), noncontiguous=True) + b = make_arg((3, 10, 3)).transpose(0, -1) + + # NOTE that the OpInfo for where takes samples of the form a, cond, b + yield SampleInput(a, args=(c, b)) + + # MPS does not support float64, which causes issues in the following tests + if torch.device(device).type == "mps": + return + + # type promoting + # FIXME(rec): shouldn't other_dtype be used two lines below? + other_dtype = torch.double if dtype is not torch.double else torch.long # noqa: F841 + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), dtype=torch.long) + b = make_arg((10, 1)) + + yield SampleInput(a, args=(c, b)) + + # two python scalars + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((1,)).item() + b = make_arg((1,)).item() + + yield SampleInput(a, args=(c, b)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + if dtype.is_floating_point: + nan = float('nan') + else: + # dtype.is_complex + nan = complex(float('nan'), float('nan')) + c = make_cond((1, 10, 3)) + a = make_arg((10, 3), noncontiguous=True) + a[2, 1] = nan + b = make_arg((1, 3)) + b[0, 2] = nan + + yield SampleInput(a, args=(c, b)) + + # Python scalars type promotion + for scalar in (0, 0.0, 2j, False): + yield SampleInput(scalar, args=(c, b)) + yield SampleInput(a, args=(c, scalar)) + + +def error_inputs_where(op_info, device, **kwargs): + shape = (S,) + err_msg = "Expected all tensors to be on the same device" + for devices in product(('cpu', device), repeat=3): + if len(set(devices)) == 2: + si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32), + args=(make_tensor(shape, dtype=torch.bool, device=devices[1]), + make_tensor(shape, device=devices[2], dtype=torch.float32))) + yield ErrorInput(si, error_regex=err_msg) + +def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + for input_t, as_tuple in product(inputs, [False, True]): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple)) + +def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + nonzero_sizes = [0, 1, XS, S, M] + + for input_t, nonzero_size in product(inputs, nonzero_sizes): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(size=nonzero_size)) + +def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ((S, S, S), (S, -1))) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=args) + +def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # shape x chunks x dim + cases = ( + ((13, 9, 11), 17, -1), + ((13, 9, 11), 11, -1), + ((13,), 12, -1), + ((15,), 12, -1), + ((15,), 7, 0), + ((15,), 9, 0), + ((3, 7), 9, 1), + ((3, 7), 9, 0), + ((3, 7), 2, 0), + ((3, 7), 3, 0), + ((3, 7), 1, 0), + ((3, 7), 1, 1), + ((4, 4), 2, 0), + ) + + for shape, chunks, dim in cases: + yield SampleInput(make_arg(shape), args=(chunks, dim)) + +def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = [ + ((S, S, S), (2,)), + ((S, S, S), (2, 1,)), + ((S, S, S), (2, -1,)), + ((S, S, S), (2, 1, True,)), + ((S, S, S), (2, -1, True,)), + ((S,), (2, 0,)), + ((S,), (2, 0, True,)), + ((), (1,)), + ((), (1, 0,)), + ((), (1, 0, True)), + ] + + yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases) + +def error_inputs_kthvalue(op_info, device, **kwargs): + # tests overlapping output fails + t = make_tensor(10, dtype=torch.float32, device=device) + indices = torch.empty((), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(t, 5, out=(t, indices)), + error_regex="unsupported operation") + + k_out_of_range_err = "selected number k out of range for dimension" + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3), + error_regex=k_out_of_range_err) + +def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, + train=None, valid_input_dim=None, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if valid_input_dim: + cases = ((S,) * i for i in valid_input_dim) + else: + cases = ((S, S), (S,), ()) + p_vals = [0.0, 0.5, 1.0] + # This is to handle special case for feature_alpha_dropout which has different + # supported dtypes depending on `train` parameter + training_vals = [train] if train is not None else [True, False] + + for case, p, training in product(cases, p_vals, training_vals): + yield SampleInput(make_arg(case), p=p, training=training) + yield SampleInput(make_arg(case)) + +def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) + + cases = ((S, S, S, S), (S,), ()) + scale_vals = [0.0, 1.0, 2.0] + + for case, scale in product(cases, scale_vals): + yield SampleInput(make_arg(case), make_mask(case), scale) + +def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high, noncontiguous=False): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high, + noncontiguous=noncontiguous) + + def make_per_sample_weight(flag, idx): + # a tensor of float / double weights, or None + # to indicate all weights should be taken to be 1 + if flag: + return make_input(idx.shape) + return None + + offsets = torch.tensor([0, 3], device=device, dtype=torch.long) + for generate_per_sample_weight in (True, False): + for mode in ('sum', 'mean', 'max'): + # per_sample_weights is only supported for mode='sum' (got mode='****') + if generate_per_sample_weight and mode in ('mean', 'max'): + continue + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # bag with zero length + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long), + 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S, S), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + # The gradient vector at `padding_idx` is not updated. + # Negative padding_idx + idx = make_long_input((6,), low=0, high=S) + idx[0] = 4 + idx[4] = 4 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': -1, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((3, 3), low=0, high=S) + # Positive padding_idx + idx[0, 0] = 2 + idx[1, 1] = 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': 2, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'norm_type': 1.0, + 'mode': mode, 'offsets': offsets, + 'per_sample_weights': per_sample_weights},) + + if mode != 'max': + # Scale the gradient based on the inverse frequency of a particular index. + # Note : smax mode does not support sparse weights + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'scale_grad_by_freq': True, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + # gradcheck not implemented for sparse tensors. + # Note : max mode does not support sparse weights + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((6, ), low=0, high=S) + idx[0] = 1 # freq more than 1 + idx[1] = 1 # freq more than 1 + idx[3] = 0 # padding_idx + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0, + 'max_norm': 1., 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + +def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high) + + # 0-D index tensor + idx = make_long_input((), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # The gradient vector at `padding_idx` is not updated. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 2 + idx[1, 1] = 2 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},) + + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 4 + idx[1, 1] = 4 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},) + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},) + + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},) + + # Scale the gradient based on the inverse frequency of a particular index. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},) + + # gradcheck not implemented for sparse tensors. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'sparse': True}) + + idx = make_long_input((3, 3), low=0, high=S) + idx[0, 0] = 1 # freq more than 1 + idx[0, 1] = 1 # freq more than 1 + idx[1, 0] = 0 # padding_idx + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, + 'padding_idx': 0, 'max_norm': 1.}) + + +def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad) + + shapes = ((), (S,), (L, M, S)) + num_classess = (-1, 10) + + return ( + SampleInput( + make_input( + shape, + low=0, + high=10 if num_classes == -1 else num_classes // 2, + ), + kwargs=dict(num_classes=num_classes), + ) + for shape, num_classes in itertools.product(shapes, num_classess) + ) + + +def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs): + rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Although most losses also support the reduce and size_average combination instead of reduce, the former is + # deprecated since 0.4.1 and thus is not tested + shapes_and_kwargs = ( + ((), None), + ((S,), dict(reduction="mean")), + ((S,), dict(reduction="sum")), + ((S,), dict(reduction="none")), + ((S, S), None), + ((S, S, S), None), + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=rhs_requires_grad),), + kwargs=kwargs) + +def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = ("bilinear", "nearest") + align_cornerss = (False, True) + padding_modes = ("zeros", "border", "reflection") + + for dim in (2, 3): + + modes_ = (*modes, "bicubic") if dim == 2 else modes + + for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, *[S] * dim)), + _make_tensor((batch_size, *[S] * dim, dim)), + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + +def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + + batch_size = 2 + num_channels = 3 + height = 345 + width = 456 + modes = ("bilinear", "nearest", "bicubic") + align_cornerss = (False, True) + padding_modes = ('zeros', 'border', 'reflection') + + # Create an affine transformation matrix + a = torch.deg2rad(torch.tensor(45.0)) + ca, sa = torch.cos(a), torch.sin(a) # rotation angles + s1, s2 = 1.23, 1.34 # scales + + theta = torch.tensor([[ + [ca / s1, sa, 0.0], + [-sa, ca / s2, 0.0], + ]], dtype=dtype, device=device) + theta = theta.expand(batch_size, 2, 3).contiguous() + + x = torch.arange(batch_size * num_channels * height * width, device=device) + x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8) + x = x.to(dtype=dtype) + x.requires_grad_(requires_grad) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + grid = torch.nn.functional.affine_grid( + theta, size=(batch_size, num_channels, height, width), align_corners=align_corners + ) + yield SampleInput( + x, + grid, + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = (0, 1, 2) + align_cornerss = (False, True) + padding_modes = (0, 1, 2) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, S, L)), + _make_tensor((batch_size, M + 3, M, 2)), + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_grid_sampler_3d(op_info, device, dtype, requires_grad, **kwargs): + _make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-1, high=1) + # Test both out-of-range and in-range grid values + _make_grid = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-4, high=4) + + modes = (0,) + padding_modes = (0, 1, 2) + align_cornerss = (False, True) + shape_pairs = [ + # [input_shape, grid_shape] + [(1, 1, 2, 2, 2), (1, 1, 1, 1, 3)], + [(2, 3, S, L, L), (2, M + 2, M + 1, M, 3)], + [(L, L + 1, L + 2, L + 3, L + 4), (L, M + 2, M + 1, M, 3)], + [(M, M + 1, M + 2, M + 3, M + 4), (M, L + 3, L + 2, L + 1, 3)], + [(L, M + 1, M + 2, M + 3, M + 4), (L, L + 3, L + 2, L + 1, 3)], + ] + + params_prod = itertools.product(modes, padding_modes, align_cornerss, shape_pairs) + + for mode, padding_mode, align_corners, (input_shape, grid_shape) in params_prod: + yield SampleInput( + _make_input(input_shape), + _make_grid(grid_shape), + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_target(shape): + shape = () if len(shape) == 1 else (shape[0], ) + t = torch.randint(0, 2, shape, device=device, dtype=torch.long) + # Label with -1 or 1 + t = t * 2 - 1 + target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad) + return target + + shapes = ((S, S), (S,)) + reductions = ('none', 'mean', 'sum') + for s, r in product(shapes, reductions): + yield SampleInput( + make_input(s), + args=(make_input(s), make_target(s)), + kwargs=dict(reduction=r, margin=random.uniform(-1, 1)) + ) + +def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs): + input_length = 50 + batch = 16 + num_char = 20 + target_length = 30 + + def make_log_probs(s): + t = make_tensor(s, device=device, dtype=dtype) + log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad) + return log_probs + + reductions = ('none', 'mean', 'sum') + zero_inf = (True, False) + lengths_type = (list, torch.Tensor) + for r, z, lt in product(reductions, zero_inf, lengths_type): + log_probs = make_log_probs((input_length, batch, num_char)) + targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device) + input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device) + target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device) + + # Dont generate int[] types if reduction = "Mean" since this results in non composite compliant calls + # to ctc_loss.IntList since a tensor needs to be created from the target lengths. + # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor + # e.g. via std::copy. Similarly symbolic/real tracing with fx will also not work + if lt is list and r in ["none", "sum"]: + input_lengths = input_lengths.tolist() + target_lengths = target_lengths.tolist() + + yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), + kwargs=dict(reduction=r, zero_infinity=z)) + + +def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + shape = (2, 3) + num_classes = shape[1] + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # FIXME: Derivative wrt. weight not implemented + make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False) + + def make_target(shape, zeros=False): + s = (shape[0], *shape[2:]) if len(shape) > 1 else () + if zeros: + return torch.zeros(s, device=device, dtype=torch.long) + else: + return make_tensor(s, + low=0, + high=shape[1] if len(shape) > 1 else shape[0], + device=device, + dtype=torch.long) + + + def gen_shape_kwargs(): + # Batched, non-batched and 2d + shapes = (shape, (num_classes,), shape + (2, 2)) + reductions = ('none', 'mean', 'sum') + for reduction, s in product(reductions, shapes): + yield make_input(s), make_target(s), dict(reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction) + if dtype.is_floating_point or dtype.is_complex: + yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction) + t = make_target(s) + ignore = num_classes // 2 + # If "mean", nll returns NaN, so it's not differentiable at those points + if t.eq(ignore).all() and reduction == "mean": + t.fill_(0) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight()) + # Test ignoring all the targets + # If "mean", nll returns NaN, so it's not differentiable at those points + if reduction != "mean": + yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target,), kwargs=kwargs) + + target = torch.tensor([-1, 2], device=device, dtype=torch.long) + yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1}) + + +def sample_inputs_binary_cross_entropy_with_logits( + op_info, device, dtype, requires_grad, **kwargs +): + make = partial(make_tensor, device=device, dtype=dtype) + make_prob = partial(make, low=0, high=1) + reductions = ("mean", "sum", "none") + + def make_weight_shape_kwargs(): + kwargs = [] + for shape in ((1,), (1, S), (S), (S, S)): + kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions]) + return kwargs + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *make_weight_shape_kwargs(), + *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions], + ] + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + make(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + mask = torch.tensor([[0, 1, 0, 1, 0], + [1, 1, 1, 1, 0], + [0, 0, 0, 1, 0], + [1, 0, 1, 1, 0], + [1, 0, 0, 1, 0]], dtype=torch.bool, device=device) + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t) + + yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + +def _generate_sample_shape_reduction(): + shapes = ((S,), (S, S), (S, S, S)) + reductions = ('none', 'mean', 'sum') + yield from product(shapes, reductions) + +def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0 + make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape(shape): + yield shape + # Broadcast + yield (*shape[:-1], 1) + yield shape[:-1] + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for t_s, v_s in product(gen_shape(s), gen_shape(s)): + yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + + for input, target, var, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, var, ), kwargs=kwargs) + +def error_inputs_gaussian_nll_loss(op_info, device, **kwargs): + _make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"), + error_type=ValueError, error_regex="abc is not valid") + + # var is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)), + error_type=ValueError, error_regex="var is of incorrect size") + + # target is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)), + error_type=RuntimeError, + error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) " + r"at non-singleton dimension 2")) + +def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for s, r in _generate_sample_shape_reduction(): + yield _make_tensor(s), _make_tensor(s), dict(reduction=r) + +def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = 1 + target[~mask] = -1 + d['margin'] = random.uniform(-9, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + + # scalar input and target. + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(_make_tensor(()), args=(_make_tensor(()), )) + +def error_inputs_hinge_embedding_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + +def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp = make_input((10, )) + inp[2] = float('nan') + target = make_input((10, )) + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Inf Handling + inp = make_input((10, )) + inp[4] = float('inf') + target = make_input((10, )) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Broadcasting + inp = make_input((5, 5)) + target = make_input((1, 5)) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + +def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + d['delta'] = random.uniform(1e-3, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + +def error_inputs_huber_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + err = 'is not a valid value for reduction' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex=err) + # delta <= 0 + for delta in (0, -1): + err = 'huber_loss does not support non-positive values for delta.' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}), + error_type=RuntimeError, error_regex=err) + +def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for li in (True, False): + for f in (True, False): + i1 = _make_tensor(s) + i2 = _make_tensor(s) + # For Poisson NLL Loss, + # target is assumed to be from + # Poisson Distribution which + # always has positive samples + t1 = _make_tensor(s, low=0) + t2 = _make_tensor(s, low=0) + + if not li: + i1.abs_() + i2.abs_() + t1.abs_() + t2.abs_() + + yield ( + i1, t1, + dict(log_input=li, full=f, reduction=r) + ) + yield ( + i2, t2, + dict(log_input=li, full=f, + eps=random.uniform(1e-8, 1e-3), + reduction=r) + ) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, ), kwargs=kwargs) + + # test INT_TO_FLOAT promotion + if dtype.is_complex: + for d in (torch.bool, torch.int64): + yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),)) + yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),)) + +def error_inputs_poisson_nll_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(5\) must match the ' + r'size of tensor b \(4\) at non-singleton ' + r'dimension 1)')) + +def error_inputs_soft_margin_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)')) + +def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs): + make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad) + + kwargss = ( + *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)], + dict(swap=True), + *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")], + ) + + for kwargs in kwargss: + input = make() + args = (make(), make()) + if with_distance: + kwargs["distance_function"] = torch.nn.PairwiseDistance() + yield SampleInput(input, args=args, kwargs=kwargs) + +def error_inputs_triplet_margin_loss(op_info, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + samples = ( + # input, args, kwargs, error_type, error_regex + # invalid reduction + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(reduction="abc"), + ValueError, "abc is not a valid value for reduction"), + + # invalid margin + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(margin=-1.0), + ValueError, "margin must be greater than 0, got -1.0"), + + # shape mismatch + (make_input(3, 5), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(5\) must match the size of tensor b \(4\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 5), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 4), make_input(3, 5)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + + # different dimensions + (make_input(3,), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 1D, positive 2D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3,), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 1D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3, 4), make_input(3,)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 2D, " + r"and negative 1D inputs")), + ) + + for input, args, kwargs, error_type, error_regex in samples: + yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), + error_type=error_type, error_regex=error_regex) + +def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): + make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) + make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad) + make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + M, N, K = 15, 32, 16 + samples = [] + # two e4m3 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e4m3 mat2 e5m2 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e5m2((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e5m2 mat2 e4m3 + mat1 = make_mat_e5m2((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + + yield from samples + +def sample_inputs_scaled_mm_v2(op_info, device, dtype, requires_grad, **kwargs): + from torch.nn.functional import ScalingType, SwizzleType + make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) + + make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + + M, N, K = 15, 32, 16 + samples = [] + # two e4m3 tensorwise + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.TensorWise, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.TensorWise, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # two e4m3 rowwise + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((M, 1)) + scale2 = make_scale((1, N)) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.RowWise, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.RowWise, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + M, K, N = 256, 512, 768 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + + dmajor, dminor = torch.cuda.get_device_capability() + + if dmajor == 9 and not torch.version.hip: + # 1x128 x 1x128 + scale1 = make_scale((K // 128, M)).t() + scale2 = make_scale((K // 128, N)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # 128x128 x 1x128 + L4 = round_up(K // 128, 4) + scale1 = make_scale((M // 128, L4)).t() + scale2 = make_scale((K // 128, N)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise128x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # 1x128 x 128x128 + L4 = round_up(K // 128, 4) + scale1 = make_scale((K // 128, M)).t() + scale2 = make_scale((N // 128, L4)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise128x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + + if dmajor >= 10: + # MXFP8 + scale1 = make_scale((M, K // 32)).to(torch.float8_e8m0fnu) + scale2 = make_scale((K // 32, N)).to(torch.float8_e8m0fnu) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x32, ], + [SwizzleType.SWIZZLE_32_4_4, ], + [scale2, ], + [ScalingType.BlockWise1x32, ], + [SwizzleType.SWIZZLE_32_4_4, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # NVFP4 + # [M, K] -> [M, K // 2] + # [K, N] -> [K // 2, N] + mat1_fp4 = _bfloat16_to_float4_e2m1fn_x2(mat1.to(torch.bfloat16)) + mat2_fp4 = _bfloat16_to_float4_e2m1fn_x2(mat2.to(torch.bfloat16).t()).t() + scale1 = make_scale((M, K // 16)).to(torch.float8_e4m3fn) + global_scale1 = make_scale((1, )) + scale2 = make_scale((K // 16, N)).to(torch.float8_e4m3fn) + global_scale2 = make_scale((1, )) + samples.append( + SampleInput( + mat1_fp4, + mat2_fp4, + [scale1, global_scale1], + [ScalingType.BlockWise1x16, ScalingType.TensorWise], + [SwizzleType.SWIZZLE_32_4_4, ], + [scale2, global_scale2], + [ScalingType.BlockWise1x16, ScalingType.TensorWise], + [SwizzleType.SWIZZLE_32_4_4, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + + + yield from samples + +def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 + + dim_3_q_shape = (batch, seq_q, head_dim) + dim_3_kv_shape = (batch, seq_kv, head_dim) + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim)) + + qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] + samples = [] + gqa_options = [True, False] + causal_options = [True, False] + for qkv_shape, is_causal, dropout_p, _enable_gqa in product( + qkv_shapes, causal_options, [0.0, 0.5], gqa_options): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q), + make(shape_kv), + make(shape_kv), + is_causal=is_causal, + dropout_p=dropout_p + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim + 8)), + is_causal=is_causal, + dropout_p=dropout_p + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + attn_mask=make((seq_q, seq_kv)), + is_causal=False, + dropout_p=0.0) + ) + + yield from samples + + +def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + mask_types = [1, 2] # UpperLeft, LowerRight + scales = [None, 1.0] + + for qkv_shape, _is_causal, dropout_p, mask_type, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=mask_type, + compute_log_sumexp=requires_grad, + scale=scale, + seqlen_k=None + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim + 8)), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + bias=make(batch, num_heads, seq_q, seq_kv), + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + ) + + # jagged (with query/keys offsets) + cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device) + cu_seqlens_k[-1] = 62 + cu_seqlens_k[0] = 0 + samples.append( + SampleInput( + make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + bias=None, + cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device), + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=2, + max_seqlen_k=2, + dropout_p=0.0, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None, + ) + ) + + yield from samples + +def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + scales = [None, 1.0] + + for qkv_shape, is_causal, dropout_p, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + cum_seq_q=None, + cum_seq_k=None, + max_q=seq_q, + max_k=seq_kv, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + )) + + yield from samples + +def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shape = (3,) + batched_shape = (2, *shape) + shapes_and_kwargs = [ + (shape, None), + (batched_shape, None), + (shape, dict(keepdim=True)), + (batched_shape, dict(keepdim=True)), + (shape, dict(p=5.0)), + (shape, dict(p=-1.0)), + (shape, dict(eps=1.0)), + ] + + return ( + SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs + ) + +def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor) + for upscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), upscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor) + for downscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), downscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes_groups = [ + ((1, 4, 10, 10), 2), + ((2, 6, 8, 8), 3), + ((2, 8, 5, 5), 4), + ] + + yield from ( + SampleInput(make_arg(shape), args=(groups,)) + for shape, groups in shapes_groups + ) + +def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps + # otherwise perturbation calculation causes Tensor value to become negative triggering + # a device-side hardware assertion + make_prob = partial(make, low=1e-6, high=1) + + reductions = ("mean", "sum", "none") + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions], + ] + + if logits: + shapes_and_kwargs.extend( + [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions] + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + (make if logits else make_prob)(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs): + sample_shapes = [(), (S), (S, S, S)] + atols = [1e-2, 1e-16] + rtols = [1e-1, 0.5] + for s, rtol, atol in product(sample_shapes, rtols, atols): + # close sample + t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + close = (t + atol).detach().requires_grad_(requires_grad) + yield SampleInput(t, close, rtol=rtol, atol=atol) + + # random sample + a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(a, b, rtol=rtol, atol=atol) + + +def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + # test COMPLEX_TO_FLOAT promotion + if dtype.is_complex: + make = partial(make_tensor, (), device=device, requires_grad=requires_grad) + yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) + yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) + +def error_inputs_l1_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)') + ) + +def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad) + + # This test case always triggers the smooth condition, since absolute difference of input and target + # is smaller than beta + yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5)) + yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0)) + +def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs): + # kl_div works with inputs in [0, 1] (aka the pdf of a probability measure) + # Then log [0, 1] = (-inf, 0], so this is the log space + make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad) + + def make_log(shape): + out = torch.nn.functional.log_softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + def make_prob(shape): + out = torch.nn.functional.softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + shapes = ((2,), (2, 3)) + reductions = ("none", "mean", "batchmean", "sum") + for shape, reduction, log_target in product(shapes, reductions, (True, False)): + input = make_log(shape) + target = make_log(shape) if log_target else make_prob(shape) + yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target)) + +def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2)) + yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))) + +def reference_pdist(input, p=2): + pdist = scipy.spatial.distance.pdist + if p == 0: + output = pdist(input, "hamming") * input.shape[1] + elif p == float("inf"): + output = pdist(input, lambda x, y: np.abs(x - y).max()) + else: + output = pdist(input, "minkowski", p=p) + return output.astype(input.dtype) + +def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(())) + yield SampleInput(make_input((2,))) + yield SampleInput(make_input((2, 2))) + yield SampleInput(make_input((2,)), offset=1) + yield SampleInput(make_input((2,)), offset=-1) + + +_UNPOOL_NAME_TO_DIM = { + 'nn.functional.max_unpool1d': 1, + 'nn.functional.max_unpool2d': 2, + 'nn.functional.max_unpool3d': 3 +} + + +def error_inputs_max_unpool(op_info, device, **kwargs): + """Error inputs for max_unpool: shape mismatch between input and indices.""" + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] + + # Create mismatched shapes for input and indices + kwargs_dict = {'kernel_size': 3, 'stride': 2, 'padding': 0} + if pool_dim == 1: + input_shape = (8, 8) + indices_shape = (8, 7) + elif pool_dim == 2: + input_shape = (1, 1, 4, 4) + indices_shape = (1, 1, 4, 1) + else: # pool_dim == 3 + input_shape = (1, 1, 4, 4, 4) + indices_shape = (1, 1, 4, 4, 1) + + yield ErrorInput( + SampleInput( + make_arg(input_shape), + args=(torch.zeros(indices_shape, device=device, dtype=torch.long),), + kwargs=kwargs_dict + ), + error_type=RuntimeError, + error_regex='Expected shape of indices to be' + ) + + +def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + unpool_name_to_pool_method_dict = { + 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d, + 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d, + 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d + } + + unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()} + + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] + pool_method = unpool_name_to_pool_method_dict[op_info.name] + + pool_op_info = copy.copy(op_info) + pool_op_info.name = unpool_to_pool_name_dict[op_info.name] + + for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs): + # shapes (C, ...) do not work as of now, + # see https://github.com/pytorch/pytorch/issues/68337 + # TODO: remove once the issue is resolved + if sample.input.dim() != pool_dim + 2: + continue + + # No dilation > 1 for max_unpool, + # see https://github.com/pytorch/pytorch/issues/68420 + if sample.kwargs['dilation'] != 1: + continue + + # Can't unpool without indices + if sample.kwargs['return_indices']: + pool, indices = pool_method(sample.input, **sample.kwargs) + # arg has to be a leaf + arg = pool.detach().requires_grad_(requires_grad) + sample_kwargs = { + 'kernel_size': sample.kwargs['kernel_size'], + 'stride': sample.kwargs['stride'], + 'padding': sample.kwargs['padding'], + # output_size could be None but we specify it explicitly + # to compensate for the information lose in pool due + # to the floor/ceil operation used to compute the shapes + 'output_size': sample.input.size() + } + + yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs) + +def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs): + for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + indices = sample.args[0] + # The samples for max_unpool are generated with max_pool. + # It could be that a single element from the max_pool's + # input is mapped to several locations in its output. + # This situation leads to failed gradchecks because + # the finite difference algorithm perturbs the elements + # of the output one by one, and not in classes of + # equivalences determined by whether two elements + # in the output are coming from the same location in the + # input (simply put, they have the same corresponding index). + # So, there are two ways to resolve this issue: + # 1. Extract a perturbation for one element and apply it all + # the elements from the same equivalence class, or + # 2. Make sure that the equivalence classes are all singletons, + # i.e. the index tensor has to be comprised of only unique + # indices. + # Here we go with the solution 2, the easiest of all. + if indices.unique().numel() == indices.numel(): + yield sample + +def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if requires_grad: + # backward tests would take too long to complete, causing the job timeout. + bsz = 2 + is_batcheds = (True,) + use_separate_proj_weights = (False,) + emb_sizes = (2,) + src_lens = (XS,) + tgt_lens = (XS,) + heads = (2,) + dropouts = (0.5,) + mask_types = ("2d",) + else: + bsz = 2 + is_batcheds = (False, True) + use_separate_proj_weights = (False, True) + emb_sizes = (2, 4) + src_lens = (XS,) + tgt_lens = (XS, S) + heads = (1, 2) + dropouts = (0.0, 0.5) + mask_types = (None, "2d", "3d") + + for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product( + is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts + ): + attn_mask = None + if mask_type == "2d": + attn_mask = make_input(src_len, tgt_len) + elif mask_type == "3d": + attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len) + + if is_batched: + q = make_input(src_len, bsz, emb_size) + k = make_input(tgt_len, bsz, emb_size) + v = make_input(tgt_len, bsz, emb_size) + else: + q = make_input(src_len, emb_size) + k = make_input(tgt_len, emb_size) + v = make_input(tgt_len, emb_size) + if use_separate_proj_weight: + in_proj_weight = None + q_proj_weight = make_input(emb_size, emb_size) + k_proj_weight = make_input(emb_size, emb_size) + v_proj_weight = make_input(emb_size, emb_size) + else: + in_proj_weight = make_input(emb_size * 3, emb_size) + q_proj_weight = None + k_proj_weight = None + v_proj_weight = None + + bias_k = make_input(emb_size) + bias_v = make_input(emb_size) + in_proj_bias = make_input(emb_size * 3) + out_proj_weight = make_input(emb_size, emb_size) + out_proj_bias = make_input(emb_size) + sample_args = ( + k, v, emb_size, num_heads, in_proj_weight, + in_proj_bias, bias_k, bias_v, False, + dropout_p, out_proj_weight, out_proj_bias + ) + sample_kwargs = { + "q_proj_weight" : q_proj_weight, + "k_proj_weight" : k_proj_weight, + "v_proj_weight" : v_proj_weight, + "attn_mask" : attn_mask, + "training" : dropout_p > 0.0, + "use_separate_proj_weight" : use_separate_proj_weight + } + + yield SampleInput(q, args=sample_args, kwargs=sample_kwargs) + + +# Includes some values such that N * N won't be a multiple of 4, +# which should ensure we test the vectorized and non-vectorized +# kernel code paths. +NUM_SIZE0_TENSORS = 10000 +foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300] +_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None} + + +class ForeachRightmostArgType(enum.Enum): + TensorList = enum.auto() + ScalarList = enum.auto() + Scalar = enum.auto() + Tensor = enum.auto() + + +class ForeachSampleInput(SampleInput): + # For TensorList Scalar/Tensor, we compute the reference + # by converting it into TensorList ScalarList/TensorList and + # then converting into multiple Tensor Scalar/Tensor. + # ref_args contains the args converted to TensorList ScalarList/TensorList + ref_args: Any + disable_fastpath: bool + + def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs): + super().__init__(*args, **kwargs) + self.ref_args = ref_args or self.args + self.disable_fastpath = disable_fastpath + + +class foreach_inputs_sample_func: + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + self.arity = arity + self._set_rightmost_arg_types( + rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, + ) + self._intersperse_empty = (True, False) + + def _set_rightmost_arg_types( + self, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool, + ) -> None: + self._rightmost_arg_types = [ForeachRightmostArgType.TensorList] + if self.arity > 1: + if rightmost_supports_scalar: + self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar) + if rightmost_supports_scalarlist: + self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList) + if rightmost_supports_tensor: + self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) + + def _sample_rightmost_arg( + self, + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ): + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] + if rightmost_arg_type == ForeachRightmostArgType.Tensor: + return [make_tensor( + (), device=device, dtype=dtype, + noncontiguous=_foreach_inputs_kwargs["noncontiguous"], + requires_grad=_foreach_inputs_kwargs.get("requires_grad", False), + )] + should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16) + + def sample_float(): + s = random.random() + if should_use_simpler_scalars: + return 1.0 if s > 0.5 else 2.0 + else: + return 1.0 - s + + high = 2 if should_use_simpler_scalars else 9 + if rightmost_arg_type == ForeachRightmostArgType.ScalarList: + scalarlist_list = [] + scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) + + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalarlist_list.append([sample_float() for _ in range(num_tensors)]) + if allow_higher_dtype_scalars or dtype.is_complex: + scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) + scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) + scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) + return scalarlist_list + if rightmost_arg_type == ForeachRightmostArgType.Scalar: + scalars = [] + scalars.append(random.randint(1, high + 1)) + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalars.append(sample_float()) + if allow_higher_dtype_scalars or dtype.is_complex: + scalars.append(complex(sample_float(), sample_float())) + scalars.append(True) + return scalars + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + if self.arity == 1: + if "foreach_abs" in opinfo.name and dtype in complex_types(): + return True + # unary + if opinfo.ref in (torch.abs, torch.neg): + return False + if opinfo.ref_inplace == torch.Tensor.zero_: + return False + return dtype in integral_types_and(torch.bool) + if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor: + return None + if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool): + return True + if any( + foreach_name in opinfo.name + for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum") + ) and dtype in integral_types_and(torch.bool): + return True + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if "foreach_add" in opinfo.name and dtype == torch.bool: + disable_fastpath = True + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.Scalar: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if isinstance(rightmost_arg, bool): + disable_fastpath |= dtype == torch.bool + if opinfo.ref in (torch.add, torch.mul): + disable_fastpath = False + elif isinstance(rightmost_arg, int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg, float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg, complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}") + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.ScalarList: + disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool) + elmt_t = type(rightmost_arg[0]) + has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg) + if not has_same_type: + return dtype not in complex_types() + if isinstance(rightmost_arg[0], bool): + if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool: + disable_fastpath = False + elif isinstance(rightmost_arg[0], int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg[0], float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg[0], complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalarlist of {rightmost_arg}") + return disable_fastpath + else: + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param: + if dtype in integral_types_and(torch.bool): + kwargs["alpha"] = 3 + elif dtype.is_complex: + kwargs["alpha"] = complex(3, 3) + else: + kwargs["alpha"] = 3.14 + if self.arity > 1: + kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype) + return kwargs + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + for _rightmost_arg_type in self._rightmost_arg_types: + zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) + zero_size_foreach_inputs_kwargs["zero_size"] = True + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + if self.arity > 1: + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + args.append( + self._sample_rightmost_arg( + opinfo, + ForeachRightmostArgType.TensorList, + device, + dtype, + NUM_SIZE0_TENSORS, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **zero_size_foreach_inputs_kwargs, + )[0]) + kwargs = self._sample_kwargs( + opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) + else: + args = [] + kwargs = {} + if opinfo.ref in (torch.abs, torch.neg): + kwargs["disable_fastpath"] = False + else: + kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _foreach_inputs_kwargs["zero_size"] = False + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + # add empty tensor interspersion to test fully fixing #100701 + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): + if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): + # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy + continue + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + args = [] + if self.arity > 1: + args = [ + sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, + **_foreach_inputs_kwargs) + for rightmost_arg in rightmost_arg_list: + args.append(rightmost_arg) + kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype) + ref_args = args + if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor): + ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]] + sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs) + yield sample + args.pop() + else: + yield ForeachSampleInput( + input, + *args, + disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype), + ) + + +class foreach_max_sample_func(foreach_inputs_sample_func): + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) + self._intersperse_empty = (False,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + return [] + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return False + + +class foreach_norm_sample_func(foreach_inputs_sample_func): + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')): + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors) + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, ord, out_dtype, intersperse_empty_tensors in product( + num_input_tensors, + (0, 1, 2, -1, -2, float('inf'), float('-inf')), + (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,), + (True, False), + ): + # inf norm and negative norms on empty tensors is not supported by our reference func vector norm: + # linalg.vector_norm cannot compute the inf norm on an empty tensor because the operation does not have an identity + if (ord in [float('inf'), float('-inf')] or ord < 0) and intersperse_empty_tensors: + continue + + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype) + + # Also test nan propagation with a single tensor, but skip autograd testing + if not requires_grad: + nan_inputs = [ + [float('nan')], + [float('nan'), 1.0], + [1.0, float('nan')], + [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0], + [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0], + [3.0, float('nan'), float('nan'), -1.5, 6.0], + ] + for input in nan_inputs: + x = torch.tensor(input, device=device) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) + + +class foreach_pointwise_sample_func(foreach_inputs_sample_func): + + def __init__( + self, + arity: int = 3, + rightmost_supports_scalar: bool = False, + rightmost_supports_scalarlist: bool = False, + ): + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist) + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + # zero_size tensor + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + for _ in range(2) + ] + kwargs.pop("scalars", None) + kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, (True, False)): + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + zero_size=False, + allow_higher_dtype_scalars=False if intersperse_empty_tensors else allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ) + for rightmost_arg in rightmost_arg_list: + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.append(rightmost_arg) + elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]: + kwargs["scalars"] = rightmost_arg + else: + kwargs["value"] = rightmost_arg + kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) + assert len(args) == 2, f"{len(args)=}" + sample = ForeachSampleInput(input, *args, **kwargs) + yield sample + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.pop() + + +foreach_unary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + 'exp', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'acos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'asin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'atan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cosh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log10', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log2', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'tan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + { + torch.complex64: tol(atol=3e-04, rtol=2e-05) + } + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'tanh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + {torch.complex64: tol(atol=5e-03, rtol=1e-04)} + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'sin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sinh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'neg', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_unary_op_tensors_on_different_devices", + device_type="cuda", + dtypes=(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'rsqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'ceil', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erf', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erfc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'expm1', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'floor', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'log1p', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'round', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'frac', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'reciprocal', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'sigmoid', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'trunc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'abs', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()), + ), + ), + ForeachFuncInfo( + 'zero', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_out=False, + ), + ForeachFuncInfo( + 'sign', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'lgamma', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)), + # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + # "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), +] + +foreach_binary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + "add", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # These tests fail with aten._local_scalar_dense not being implemented. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), + ), + ), + ForeachFuncInfo( + "sub", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "mul", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "div", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "clamp_min", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "clamp_max", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "minimum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=False, + supports_forward_ad=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "maximum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_forward_ad=False, + supports_inplace_autograd=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "pow", + supports_alpha_param=False, + supports_scalar_self_arg=True, + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,),), + DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), + DecorateInfo( + unittest.skip("failed starting on ROCm 6.2"), + "TestForeach", + "test_parity", + device_type="cuda", + dtypes=(torch.complex64,), + active_if=TEST_WITH_ROCM), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_with_scalar_self_support", + device_type="cuda", + dtypes=(torch.bool,), + active_if=lambda kwargs: kwargs["is_fastpath"], + ), + ), + backward_requires_result=True, + ), + ForeachFuncInfo( + "copy", + sample_inputs_func=foreach_inputs_sample_func(2, False, False), + supports_out=False, + supports_forward_ad=False, + supports_autograd=False, + supports_inplace_autograd=False, + ) +] + +foreach_pointwise_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "addcmul", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.bool,)), + # # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "addcdiv", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + # fails with div_cpu is not implemented with ComplexHalf + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), +] + +foreach_reduce_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "max", + sample_inputs_func=foreach_max_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # no complex support for ordering ops like max + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + ), + ), + ForeachFuncInfo( + "norm", + sample_inputs_func=foreach_norm_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + device_type="cuda", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +foreach_other_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "lerp", + sample_inputs_func=foreach_inputs_sample_func(3, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +def reference_sign(x): + if x.dtype == np.bool_: + # `np.sign` doesn't support `bool`. + # >>> np.sign(True) + # ufunc 'sign' did not contain a loop + # with signature matching types dtype('bool') -> dtype('bool') + return np.sign(x, dtype=np.uint8).astype(np.bool_) + return np.sign(x) + + +def reference_sgn(x): + # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex. + # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. + # while `torch.sgn` returns, 0 if abs(input) == 0 else input/abs(input) + if x.dtype not in [np.complex64, np.complex128]: + return reference_sign(x) + + out = (x / np.abs(x)) + if out.ndim == 0: + # Handle x == 0 case + if (x == 0): + # Can't assign to np.complex object + # So make a new one. + return np.array(complex(0, 0), dtype=x.dtype) + return out + + # Handle x == 0 case + mask = (x == 0) + out[mask] = complex(0, 0) + return out + + +def reference_sigmoid(x): + # 'scipy.special.expit' not supported for the input types + if x.dtype in [np.complex64, np.complex128]: + return (1 / (1 + np.exp(-x))) + return scipy.special.expit(x) + + +def reference_logsigmoid(x): + return np.where( + x < 0, + x - np.log1p(np.exp(x)), + -np.log1p(np.exp(-x))) + + +def reference_hardsigmoid(x): + intermediate = x / 6 + 0.5 + y = np.clip(intermediate, 0, None) + return np.where(y > 1, 1, y).astype(x.dtype) + + +def reference_lgamma(x): + # scipy.special.gammaln returns `-inf` when input is `-inf`. + # While Pytorch, C and C++, all return `inf` when input is `-inf`. + # Reference: + # https://en.cppreference.com/w/cpp/numeric/math/lgamma + # https://en.cppreference.com/w/c/numeric/math/lgamma + + # To handle the above discrepancy, + # we replace -inf with inf so values + # that were originally -inf map to inf as expected + if x.dtype.kind == 'f': + x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x) + + out = scipy.special.gammaln(x) + + if x.dtype == np.float16: + # `scipy.special.gammaln` returns output of float32 when input is float16, + # while `torch.lgamma` preserves `float16`. But due to smaller range of float16, + # Pytorch version outputs `inf` while SciPy returns finite values. + out = out.astype(np.float16) + + return out + + +def reference_mvlgamma(x, d): + if x.dtype == np.float16: + return scipy.special.multigammaln(x, d).astype(np.float16) + + return scipy.special.multigammaln(x, d) + +def reference_softplus(input, beta=1, threshold=20): + non_linear = input * beta <= threshold + output = input.copy() + output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta + return output + +def reference_gelu(X, *, approximate='none'): + def _gelu_ref(X): + return X * stats.norm.cdf(X) + + def _tanh_gelu_ref(X): + M_SQRT_2_PI = math.sqrt(2 / math.pi) + Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0)) + return 0.5 * X * (1.0 + np.tanh(Z)) + + if approximate == 'tanh': + return _tanh_gelu_ref(X) + else: + return _gelu_ref(X) + + +def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray: + if num_classes == -1: + num_classes = int(np.amax(a) + 1) + + idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes + one_hot = np.zeros((a.size, num_classes), dtype=a.dtype) + np.put(one_hot, idcs, 1) + return one_hot.reshape(*a.shape, -1) + + +def reference_mse_loss(input, target, reduction="mean"): + se = (input - target) ** 2 + if reduction == "mean": + return np.mean(se) + elif reduction == "sum": + return np.sum(se) + else: # reduction == "none" + return se + + +def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, bias=None, eps=1e-5): + return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] + + +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight, bias, eps): + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + if weight is None and bias is not None: + Y = Y + bias.reshape(-1) + elif weight is not None and bias is None: + Y = Y * weight.reshape(-1) + elif weight is not None and bias is not None: + Y = Y * weight.reshape(-1) + bias.reshape(-1) + axis = inp.ndim - len(normalized_shape) + stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape) + return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) + + +def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, eps=None): + if eps is None: + eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps) + Y = inp_view / rms + if weight is not None: + Y = Y * weight.reshape(-1) + return Y.reshape(*inp.shape) + + +def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5): + inp_view = inp + if np.prod(inp.shape) != 0: + inp_view = inp.reshape((inp.shape[0], num_groups, -1)) + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + Y = Y.reshape(inp.shape) + if weight is not None: + # weight is a vector of length equal to the channel + if len(Y.shape) > 2: + weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y * weight + if bias is not None: + # bias is a vector of length equal to the channel + if len(Y.shape) > 2: + bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y + bias + return Y + + +# using a custom reference function since numpy only has a string side arg (instead of right and side) and doesn't +# have an out_int32 arg. Additionally, numpy doesn't support searchsorted with ND arrays, so this splits those into +# stacked 1D cases +def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None): + side = 'right' if (right or side == 'right') else 'left' + if len(sorted_sequence.shape) == 1 : + ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter) + return ret.astype(np.int32) if out_int32 else ret + elif sorted_sequence.shape[0] == 0: + if sorter is not None: + sorter = sorter.flatten() + ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter) + ret = ret.astype(np.int32) if out_int32 else ret + return ret.reshape(boundary.shape) + else: + # numpy searchsorted only supports 1D inputs so we split up ND inputs + orig_shape = boundary.shape + num_splits = np.prod(sorted_sequence.shape[:-1]) + splits = range(num_splits) + sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1) + if sorter is not None: + sorter = sorter.reshape(num_splits, -1) + + split_sequence = [sorted_sequence[i] for i in splits] + split_boundary = [boundary[i] for i in splits] + split_sorter = [sorter[i] if (sorter is not None) else None for i in splits] + + split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort) + for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter, strict=True)] + split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret + return np.stack(split_ret).reshape(orig_shape) + +def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0): + assert mode == 0, "Only mode=0 (xor_sum) is supported right now" + + dtype = tensor.dtype + if dtype.kind == 'f': + tensor = tensor.astype(np.float64).view(np.uint64) + else: + tensor = tensor.astype(np.uint64) + + + if dim == (): + result = np.bitwise_xor.reduce(tensor.flatten(), keepdims=keepdim) + else: + if isinstance(dim, list): + dim = tuple(dim) + result = np.bitwise_xor.reduce(tensor, axis=dim, keepdims=keepdim) + + return result + + +def loss_reference_reduction_wrapper(fn): + def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs): + if size_average is not None or reduce is not None: + raise RuntimeError( + "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper" + ) + output = fn(input, target, **other_kwargs) + if reduction == "mean": + return np.mean(output) + elif reduction == "sum": + return np.sum(output) + else: # reduction == "none" + return output + + return wrapper + +@loss_reference_reduction_wrapper +def reference_smooth_l1_loss(input, target, beta=1.0): + diff = input - target + abs_diff = np.abs(diff) + above_threshold = abs_diff >= beta + + loss = np.empty_like(input) + loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta + loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta) + + return loss + +def reference_std_var(f): + """Forwards unbiased/correction kwargs as NumPy's equivalent ddof""" + g = reference_reduction_numpy(f) + + @wraps(g) + def wrapper(x: npt.NDArray, *args, **kwargs): + assert not ('unbiased' in kwargs and 'correction' in kwargs) + + if 'unbiased' in kwargs: + kwargs['ddof'] = int(kwargs.pop('unbiased')) + elif 'correction' in kwargs: + kwargs['ddof'] = kwargs.pop('correction') + + return g(x, *args, **kwargs) + + return wrapper + +def generate_std_var_kwargs(t: torch.Tensor, **kwargs): + """Generates unbiased/correction kwargs for std/var operators""" + yield ((), {'unbiased': True}) + yield ((), {'unbiased': False}) + + # Currently, calling std with correction is only enabled when + # both dim and keepdim are provided. + if 'dim' in kwargs and 'keepdim' in kwargs: + yield ((), {'correction': 0}) + yield ((), {'correction': 1}) + + numel = torch.tensor(t.shape)[kwargs.get('dim')].prod() + yield ((), {'correction': numel // 2}) + +def error_inputs_mean(op_info, device, is_ref=False, **kwargs): + if is_ref: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []), + error_regex=err_msg1, + ) + + if is_ref: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput( + make_tensor((3, 4, 5), dtype=torch.float32, device=device), + [], + dtype=torch.int64), + error_regex=err_msg2 + ) + +# numpy implementation of torch.flatten +# unfortunately there's no np.flatten. we figure out the desired shape and call np.reshape +def reference_flatten(input, start_dim=0, end_dim=-1): + in_shape = input.shape + in_rank = len(in_shape) + for d in start_dim, end_dim: + if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}") + end_dim = end_dim if end_dim >= 0 else in_rank + end_dim + start_dim = start_dim if start_dim >= 0 else in_rank + start_dim + if in_rank == 0: + end_dim = start_dim + if end_dim < start_dim: + raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim") + flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1) + out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] + return np.reshape(input, out_shape) + + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + +# Operator database (sorted alphabetically) +op_db: list[OpInfo] = [ + UnaryUfuncInfo('abs', + aliases=('absolute', ), + ref=np.abs, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_grad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_gradgrad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients', + 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs", + "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input) + # We can break the logic of the loop over all possible types but it is OK. + # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + ), + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True), + # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) + UnaryUfuncInfo('acos', + aliases=('arccos', ), + ref=np.arccos, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-1, + torch.complex64: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS),)), + # NOTE: the derivative for inplace acosh is not implemented + UnaryUfuncInfo('acosh', + aliases=('arccosh', ), + ref=np.arccosh, + domain=(1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + # acosh is not defined at x < 1 (real) + reference_numerics_filter=NumericsFilter( + condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)), + safe_val=2)), + BinaryUfuncInfo('add', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: ( + np.add(input, other) + if alpha == 1 + else np.add(input, np.multiply(alpha, other)) + ), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_add_sub, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestCommon', + 'test_numpy_refs', + dtypes=(torch.complex128,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('item', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs), + ref=np.ndarray.item, + method_variant=None, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool), + dtypesIfHpu=custom_types(torch.float32), + supports_out=False, + supports_autograd=False, + error_inputs_func=error_inputs_item, + sample_inputs_func=sample_inputs_item, + skips=( + # Error testing item function variant + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32, torch.complex64)), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: Composite compliance check failed with the above error. + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'), + )), + OpInfo('arange', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_out=True, + supports_autograd=False, + is_factory_function=True, + error_inputs_func=error_inputs_arange, + sample_inputs_func=sample_inputs_arange, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Lazy tensor failures + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + + # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608 + # We don't have an op for aten::arange but it isn't a special case. + # Argument types: bool, bool, bool, int, int, Device, boo + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + + # Captured graph does not contain aten::arange (succeeds on complex!) + # g: graph(): + # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # return (%25) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('cauchy', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.cauchy_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_cauchy, + error_inputs_func=error_inputs_cauchy, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('exponential', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.exponential_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_exponential, + error_inputs_func=error_inputs_exponential, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('geometric', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.geometric_, + dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_geometric, + error_inputs_func=error_inputs_geometric, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('log_normal', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.log_normal_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_log_normal, + error_inputs_func=error_inputs_log_normal, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('normal', + variant_test_name='in_place', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.normal_, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_normal, + error_inputs_func=error_inputs_normal, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + )), + OpInfo('uniform', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs), + method_variant=None, + inplace_variant=torch.Tensor.uniform_, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + is_factory_function=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_uniform, + error_inputs_func=error_inputs_uniform, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # aten.uniform was not decomposed + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('clamp_max', + ref=_clamp_max_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('clamp_min', + ref=_clamp_min_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('mul', + aliases=('multiply',), + dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + error_inputs_sparse_func=error_inputs_sparse_mul, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)), + BinaryUfuncInfo('sub', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), + aliases=('subtract',), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_add_sub, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + )), + OpInfo('addmm', + # This addmm OpInfo is for when alpha and beta are not both equal to 1. + # alpha=beta=1 is tested in the following opinfo, because that special case will + # trigger addmm being decomposed by a jit pass. + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_addmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('addmm', + # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add. + variant_test_name='decomposed', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], + sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1), + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # https://github.com/pytorch/pytorch/issues/71784 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.float16,)), + )), + OpInfo('addmv', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + sample_inputs_func=sample_inputs_addmv), + OpInfo('addbmm', + ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M), + np.multiply(np.asarray(alpha, dtype=batch1.dtype), + np.sum(np.matmul(batch1, batch2), axis=0))), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_refs'), + # MPS has slightly worse precision. Is this acceptable? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}), + 'TestConsistency', + 'test_output_match', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}), + 'TestCommon', 'test_out'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ], + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # addbmm does not correctly warn when resizing out= inputs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # https://github.com/pytorch/pytorch/issues/55907 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addbmm), + OpInfo('baddbmm', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [], + torch.complex64, torch.complex128), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + # Higher differences starting with Zen3 or Alder Lake + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=4e-05, rtol=4e-06)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view', device_type='cuda'), + ], + sample_inputs_func=sample_inputs_baddbmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('dot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('vdot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('bmm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), + "TestCommon", "test_out"), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + ), + sample_inputs_func=sample_inputs_bmm), + OpInfo('mv', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mv), + OpInfo('addr', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + # Reference: https://github.com/pytorch/pytorch/issues/50747 + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/50747 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + ), + sample_inputs_func=sample_inputs_addr, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('addcmul', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + OpInfo('addcdiv', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + UnaryUfuncInfo('asin', + aliases=('arcsin', ), + ref=np.arcsin, + domain=(-1, 1), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + # NOTE: derivative for inplace asinh is not implemented + UnaryUfuncInfo('asinh', + aliases=('arcsinh', ), + ref=np.arcsinh, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('atan', + aliases=('arctan', ), + ref=np.arctan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + BinaryUfuncInfo('atan2', + aliases=('arctan2',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + )), + UnaryUfuncInfo('atanh', + aliases=('arctanh', ), + ref=np.arctanh, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=[ + precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('allclose', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=np.allclose, + supports_autograd=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_allclose, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_out=False), + OpInfo('broadcast_to', + ref=np.broadcast_to, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_broadcast_to), + OpInfo('broadcast_shapes', + op=torch.broadcast_shapes, + ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None, + dtypes=_dispatch_dtypes((torch.float32,)), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + supports_autograd=False, + supports_scripting=False, + sample_inputs_func=sample_inputs_broadcast_shapes, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # skip dtype tests since broadcast_shape is not device dependent. + # having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('broadcast_tensors', + ref=np.broadcast_arrays, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_broadcast_tensors, + reference_inputs_func=reference_inputs_broadcast_tensors, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + )), + OpInfo('block_diag', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Default batching rule in core doesn't work for ops with TensorList args + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_block_diag), + UnaryUfuncInfo('bitwise_not', + ref=np.bitwise_not, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.invert, + supports_autograd=False), + BinaryUfuncInfo('bitwise_left_shift', + op=torch.bitwise_left_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.lshift, + inplace_operator_variant=operator.ilshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('bitwise_right_shift', + op=torch.bitwise_right_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.rshift, + inplace_operator_variant=operator.irshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('combinations', + op=torch.combinations, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_combinations), + OpInfo('cartesian_prod', + op=torch.cartesian_prod, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_cartesian_prod, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, + 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('cdist', + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_cdist), + UnaryUfuncInfo('ceil', + ref=np.ceil, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('cholesky', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_cholesky, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],), + OpInfo('cholesky_inverse', + dtypes=floating_and_complex_types(), + backward_dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + check_batched_gradgrad=True, + sample_inputs_func=sample_inputs_linalg_cholesky_inverse, + gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal, + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestCommon', device_type='cpu', + ), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestEagerFusionOpInfo', device_type='cpu', + ), + ], + skips=( + # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),), + ), + OpInfo('cholesky_solve', + op=torch.cholesky_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_cholesky_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + OpInfo('chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_chunk, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('unsafe_chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_chunk, + check_batched_forward_grad=False, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('clone', + ref=np.copy, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format' + # (NumPy reference needs to be extended with memory_format) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + ),), + OpInfo('contiguous', + op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_fusible_nodes=['aten::contiguous'], + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('sum_to_size', + op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sum_to_size, + error_inputs_func=error_inputs_sum_to_size, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)), + )), + OpInfo('clamp', + aliases=('clip',), + ref=_clamp_numpy, + dtypes=all_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clamp, + reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NNC appear to not handle boolean clamp + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + # MPS does not support float64, while numpy does internal computations in float64. + # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264 + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('positive', + ref=np.positive, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + ), + UnaryUfuncInfo('conj', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.int32), + supports_sparse=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False), + UnaryUfuncInfo('conj_physical', + decomp_aten_name='_conj_physical', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # RuntimeError: inputSet && outputSet + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )), + DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"), + 'TestSparseUnaryUfuncs', 'test_inplace'), + )), + OpInfo('resolve_conj', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('resolve_neg', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('view_as_real', + dtypes=complex_types(), + supports_forward_ad=True, + supports_out=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_view_as_real, + test_conjugated_samples=False, + ), + OpInfo('view_as_complex', + dtypes=floating_types_and(torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + test_neg_view=False, + sample_inputs_func=sample_inputs_view_as_complex, + skips=( + # RuntimeError: Tensor must have a last dimension with stride 1 + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + BinaryUfuncInfo('complex', + dtypes=floating_types_and(torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + error_inputs_func=error_inputs_complex, + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)), + BinaryUfuncInfo('copysign', + sample_inputs_func=sample_inputs_copysign, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('corrcoef', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_corrcoef, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False), + UnaryUfuncInfo('cos', + ref=np.cos, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('cosh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + OpInfo('cov', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_cov, + error_inputs_func=error_inputs_cov, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Float did not match double + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Jacobian mismatch + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507) + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950 + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cpu"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('cross', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_cross, + supports_fwgrad_bwgrad=True, + supports_out=True, + supports_forward_ad=True), + OpInfo('cumsum', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumsum does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + sample_inputs_func=sample_inputs_cumulative_ops), + OpInfo('cumprod', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumprod does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + # gradgradcheck fails in fast_mode=True: #56275 + sample_inputs_func=sample_inputs_cumprod, + gradcheck_fast_mode=False), + OpInfo('cummax', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('cummin', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + UnaryUfuncInfo('deg2rad', + ref=np.radians, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + OpInfo('diff', + op=torch.diff, + # np.diff has np._NoValue as default values for prepend and append, compare_with_reference breaks if prepend/append + # are set as None when converting to numpy + ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: ( + np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append) + ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diff, + error_inputs_func=error_inputs_diff, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='no_rounding_mode', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True),), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='trunc_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='floor_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + BinaryUfuncInfo('true_divide', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + OpInfo('equal', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + ref=lambda input, other: (input == other).all(), + sample_inputs_func=sample_inputs_equal, + supports_autograd=False, + supports_tracing=False, + skips=( + )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + OpInfo('expand', + op=lambda self, shape: self.expand(shape), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('expand_as', + op=lambda self, other: self.expand_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_expand_as, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),), + ), + OpInfo('expand_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('diag', + ref=np.diag, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_diag, + error_inputs_func=error_inputs_diag), + OpInfo('diag_embed', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal', + aten_backward_name='diagonal_backward', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + BinaryUfuncInfo('eq', + ref=np.equal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + always_returns_bool=True, + supports_autograd=False, + sample_inputs_func=sample_inputs_comparison_ops, + skips=( + )), + BinaryUfuncInfo('fmax', + op=torch.fmax, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmin', + op=torch.fmin, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmod', + ref=np.fmod, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(101.6283, dtype=torch.float64) + # analytical:tensor(-18.3575, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + )), + BinaryUfuncInfo('remainder', + ref=np.remainder, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + operator_variant=operator.mod, + inplace_operator_variant=operator.imod, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + # Fails on XLA + # False is not true : Tensors failed to compare as equal! + # Attempted to compare equality of tensors with different dtypes + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(102.4676, dtype=torch.float64) + # analytical:tensor(-17.5182, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=5e-4, rtol=3e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + UnaryUfuncInfo('frac', + ref=lambda x: np.modf(x)[0], + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + # 76047 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.bfloat16, torch.float32, torch.float64)), + )), + OpInfo('stft', + decorators=[ + skipCPUIfNoFFT, + DecorateInfo(unittest.skip("Skipped! stft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ], + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_stft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + ), + OpInfo('istft', + dtypes=complex_types(), + sample_inputs_func=sample_inputs_istft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + decorators=( + DecorateInfo(unittest.skip("Skipped! istft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ), + skips=( + skipCPUIfNoFFT, + # gradcheck fails on ROCm (gh-68429) + # grad is computed improperly (probably for weights tensor) + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + )), + UnaryUfuncInfo('floor', + ref=np.floor, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('flip', + op=torch.flip, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_flip, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('fliplr', + op=torch.fliplr, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_fliplr, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('flipud', + op=torch.flipud, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_flipud, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('sparse.sampled_addmm', + dtypes=floating_and_complex_types(), + supports_autograd=True, + sample_inputs_func=sample_inputs_sparse_sampled_addmm, + decorators=[ + skipCPUIfNoMklSparse, + skipXPU], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype. + # RuntimeError: Sparse CSR tensors do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend. + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + )), + OpInfo('sparse.mm', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + variant_test_name='reduce', + supports_autograd=True, + supports_out=False, + supports_gradgrad=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_sparse_mm_reduce, + decorators=[onlyCPU], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # RuntimeError: Sparse CSR tensors do not have is_contiguou + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'), + # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + )), + UnaryUfuncInfo('i0', + ref=np_unary_ufunc_integer_promotion_wrapper( + scipy.special.i0) if TEST_SCIPY else None, + aliases=('special.i0',), + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_i0_i1, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.int8,)), + )), + BinaryUfuncInfo('floor_divide', + ref=_floor_divide_np, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + + supports_autograd=False, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + supports_two_python_scalars=True, + skips=( + # AssertionError: Results of original model and exported/imported version of model differed + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + )), + UnaryUfuncInfo('frexp', + op=torch.frexp, + ref=np.frexp, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # skip testing torch.frexp as it is not supported by ROCm platform yet + decorators=[], + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # skips below tests as torch.frexp returns tuple-like (mantissa, exponent) as outputs, + # while these tests currently requires output to a single tensor. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + + # skips test_reference_numerics due to error in Windows CI. + # The np.frexp returns exponent as np.intc dtype on Windows platform, + # and np.intc does not have the correspond torch dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('log1p', + ref=np.log1p, + aliases=('special.log1p',), + domain=(-1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + promotes_int_to_float=True), + BinaryUfuncInfo('ge', + ref=np.greater_equal, + aliases=('greater_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('geqrf', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + supports_autograd=False, + skips=( + # FIXME: geqrf can't forward with complex inputs that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + BinaryUfuncInfo('gt', + ref=np.greater, + aliases=('greater',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + UnaryUfuncInfo('imag', + ref=np.imag, + dtypes=complex_types_and(torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo('gradient', + dtypes=floating_and_complex_types_and(torch.int8, torch.int16, + torch.int32, torch.int64, + torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # following tests give a runtime error with undefined value tensor + # see discussion : https://github.com/pytorch/pytorch/issues/56660 + # RuntimeError: + # Arguments for call are not valid. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950 + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_inplace_autograd=False, + sample_inputs_func=sample_inputs_gradient, + error_inputs_func=error_inputs_gradient), + OpInfo('isin', + dtypes=all_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + sample_inputs_func=sample_inputs_isin), + OpInfo('kthvalue', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kthvalue, + error_inputs_func=error_inputs_kthvalue), + BinaryUfuncInfo('le', + ref=np.less_equal, + aliases=('less_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + UnaryUfuncInfo('log', + ref=np.log, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log10', + ref=np.log10, + domain=(0, None), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log10(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log2', + ref=np.log2, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + # log2(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + BinaryUfuncInfo('ldexp', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + promotes_int_to_float=True, + supports_out=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: mul(): functions with out=... arguments don't support + # automatic differentiation, but one of the arguments requires grad + # https://github.com/pytorch/pytorch/issues/68966 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.complex64: tol(atol=1e-05, rtol=1e-05) + }), + 'TestCommon', device_type='cpu', + ), + ], ), + BinaryUfuncInfo('logaddexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False), + OpInfo('logaddexp2', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_logaddexp), + UnaryUfuncInfo('logical_not', + ref=np.logical_not, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + skips=( + # The function variant always returns BoolTensor + # while the inplace variant preserves the input dtype. + # >>> t = torch.randn(3) + # >>> torch.logical_not(t) + # tensor([False, False, False]) + # >>> torch.logical_not(t).dtype + # torch.bool + # >>> t.logical_not_().dtype + # torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + )), + BinaryUfuncInfo('lt', + ref=np.less, + aliases=('less',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('lu_unpack', + op=torch.lu_unpack, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), + sample_inputs_func=sample_inputs_lu_unpack), + OpInfo('lu', + op=torch.lu, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # we skip jit tests because `lu` is a torch function + # RuntimeError: + # 'Tensor (inferred)' object has no attribute or method 'lu'.: + # File "", line 3 + # def the_method(i0): + # return i0.lu(True, True) + # ~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('lu_solve', + op=torch.lu_solve, + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_lu_solve, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Tests different backward paths"), + "TestCommon", "test_floating_inputs_are_differentiable"),), + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]), + OpInfo('masked_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_fill, + error_inputs_func=error_inputs_masked_fill, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + supports_out=False), + OpInfo('masked_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_scatter, + error_inputs_func=error_inputs_masked_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('masked_select', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_masked_select, + error_inputs_func=error_inputs_masked_select, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('matrix_exp', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + aliases=('linalg.matrix_exp',), + sample_inputs_func=sample_inputs_matrix_exp, + # Needs to construct a 2nx2n matrix by copy_ ing into it + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # mexp does not support bf16 and fp16 + DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive', + dtypes=[torch.half], device_type="cpu"), + ), + supports_out=False, + ), + OpInfo('matmul', + aliases=('linalg.matmul',), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False), + decorators=[ + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # ROCm intermittently fails the test with standard atol/rtol + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda', + active_if=TEST_WITH_ROCM), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_out', device_type='cuda', + active_if=TEST_WITH_ROCM), + # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the + # backward on CPU + DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu'), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.complex64: tol(atol=1e-5, rtol=1e-5), + }), + "TestDecomp", "test_comprehensive", device_type="cuda", + ), + ], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', + device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + OpInfo('max', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + skips=( + ), + supports_forward_ad=True), + OpInfo('max', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('median', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # TODO: some signatures of median do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_median, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('nanmedian', + dtypes=all_types_and(torch.bfloat16, torch.float16), + # TODO: some signatures of nanmedian do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('var_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('var_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of std_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestDecomp", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + )), + OpInfo('meshgrid', + variant_test_name='variadic_tensors', + ref=np.meshgrid, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'), + skips=[ + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ], + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('meshgrid', + variant_test_name='list_of_tensors', + # Unlike the variant above, we do not use np.meshgrid as a + # ref since it does not officially support list of numpy + # arrays. + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'), + skips=[ + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + ], + assert_autodiffed=True, + supports_out=False, + autodiff_nonfusible_nodes=[], + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('min', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + skips=( + )), + OpInfo('min', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('quantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + OpInfo('nanquantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + BinaryUfuncInfo( + 'max', + aliases=('maximum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'maximum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'min', + aliases=('minimum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo( + 'minimum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + ), + ), + BinaryUfuncInfo('logical_and', + ref=np.logical_and, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_or', + ref=np.logical_or, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_xor', + ref=np.logical_xor, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False, + skips=( + )), + BinaryUfuncInfo('bitwise_and', + ref=np.bitwise_and, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.and_, + inplace_operator_variant=operator.iand, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # RuntimeError: "bitwise_and_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_or', + ref=np.bitwise_or, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.or_, + inplace_operator_variant=operator.ior, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_xor', + ref=np.bitwise_xor, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.xor, + inplace_operator_variant=operator.ixor, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('heaviside', + ref=lambda a, b: ( + # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64 + np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b) + ), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: heaviside is not yet implemented for tensors with different dtypes. + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + BinaryUfuncInfo('lcm', + ref=np.lcm, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('gcd', + ref=np.gcd, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)),)), + BinaryUfuncInfo('isclose', + ref=np.isclose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_isclose, + error_inputs_func=error_inputs_isclose, + supports_autograd=False, + supports_out=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_refs', dtypes=(torch.complex128,)), + # RuntimeError: Short did not match Int + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + # `softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + aten_name='softmax', + aten_backward_name='_softmax_backward_data', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + variant_test_name="with_dtype", + aten_name='softmax', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo( + '_softmax_backward_data', + op=torch.ops.aten._softmax_backward_data, + aten_name='_softmax_backward_data', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_softmax_backward_data, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + ), + # `softmin` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('nn.functional.softmin', + aten_name='softmin', + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=False, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.softmin', + variant_test_name="with_dtype", + aten_name='softmin', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo( + "nn.functional.cross_entropy", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_cross_entropy, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536 + # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked + # 1536 bytes CUDA memory on device 0 + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + DecorateInfo(unittest.skip("FP16 corss_entropy cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ) + ), + OpInfo('nn.functional.normalize', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_normalize, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('aminmax', + ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + decorators=(onlyNativeDeviceTypes,), + supports_autograd=False, + sample_inputs_func=sample_inputs_aminmax, + error_inputs_func=error_inputs_aminmax_amax_amin), + OpInfo('as_strided', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + )), + OpInfo('as_strided', + variant_test_name='partial_views', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_partial_views, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # These fail because the test changes the input's in-memory layout + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # Fail but are also flaky + DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'), + DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon', + 'test_non_standard_bool_values'), + # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a + # storage size of 28 are out of bounds for storage of size 20 + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides'), + )), + OpInfo('as_strided_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + )), + OpInfo('as_strided_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_scatter, + error_inputs_func=error_inputs_as_strided_scatter, + skips=( + DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950 + DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950 + DecorateInfo(unittest.skip('Fails on cuda'), 'TestCommon', 'test_complex_half_reference_testing', + active_if=not TEST_WITH_ROCM), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # AssertionError: Tensor-likes are not close! (new_empty_strided.default) + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)), + OpInfo('native_layer_norm', + aten_name='native_layer_norm', + ref=reference_native_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + assert_jit_shape_analysis=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_native_layer_norm, + error_inputs_func=error_inputs_native_layer_norm, + skips=( + # IndexError: tuple index out of range + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'), + # Tests fail when weight=None and bias is defined + # https://github.com/pytorch/pytorch/issues/79705 + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + # JIT test also tries to compute double backward, which fails + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}), + "TestDecomp", "test_comprehensive", device_type="cpu"), + )), + OpInfo('native_batch_norm', + aten_name='native_batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs_native_batch_norm, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + # AssertionError: Booleans mismatch: True is not False + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_native_batch_norm_legit', + aten_name='_native_batch_norm_legit', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__native_batch_norm_legit, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_batch_norm_with_update', + op=torch.ops.aten._batch_norm_with_update, + aten_name='_batch_norm_with_update', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__batch_norm_with_update, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + # _batch_norm_with_update expects contiguous inputs for cudnn and miopen + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"), + DecorateInfo(unittest.expectedFailure, + 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"), + # _batch_norm_with_update does not have python bindings + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # aten out variants do not accept out= kwarg, only python out variants + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + ) + ), + OpInfo('nn.functional.cosine_similarity', + aten_name="cosine_similarity", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ], + sample_inputs_func=sample_inputs_cosine_similarity), + OpInfo('nn.functional.adaptive_avg_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool1d, + sample_inputs_func=sample_inputs_adaptive_avg_pool1d), + OpInfo('nn.functional.adaptive_avg_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool2d, + sample_inputs_func=sample_inputs_adaptive_avg_pool2d), + OpInfo('nn.functional.adaptive_avg_pool3d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool3d, + sample_inputs_func=sample_inputs_adaptive_avg_pool3d), + OpInfo('nn.functional.adaptive_max_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool1d, + sample_inputs_func=sample_inputs_adaptive_max_pool1d), + OpInfo('nn.functional.adaptive_max_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + decorators=( + # RuntimeError: + # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool2d, + sample_inputs_func=sample_inputs_adaptive_max_pool2d), + OpInfo('nn.functional.adaptive_max_pool3d', + dtypes=floating_types_and(torch.bfloat16, torch.half), + decorators=( + # RuntimeError: + # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool3d, + sample_inputs_func=sample_inputs_adaptive_max_pool3d), + OpInfo('nn.functional.avg_pool1d', + aten_name='avg_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool1d, + sample_inputs_func=sample_inputs_avgpool1d), + OpInfo('nn.functional.avg_pool3d', + aten_name='avg_pool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool3d, + sample_inputs_func=sample_inputs_avgpool3d, + skips=( + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + )), + OpInfo( + "nn.functional.binary_cross_entropy_with_logits", + aten_name="binary_cross_entropy_with_logits", + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + UnaryUfuncInfo( + 'nn.functional.relu', + aten_name="relu", + ref=lambda a: np.where(a <= 0, 0, a), + supports_autograd=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_nn_activation_relu, + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True), + OpInfo('nn.functional.conv_transpose1d', + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d), + aten_name='conv_transpose1d', + aliases=('conv_transpose1d',), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64,)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float,)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose2d', + aten_name='conv_transpose2d', + aliases=('conv_transpose2d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose2d, + # Runs very slowly on slow-gradcheck for complex. + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # AssertionError: None mismatch: torch.complex64 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose3d', + aten_name='conv_transpose3d', + aliases=('conv_transpose3d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and( + torch.float16, torch.chalf, torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose3d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }), + 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }), + 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda', + active_if=TEST_CUDNN), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}), + "TestMathBits", "test_conj_view", device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=[torch.complex32], active_if=TEST_WITH_ROCM), + ), + supports_out=False,), + OpInfo('nn.functional.conv1d', + aliases=('conv1d',), + aten_name='conv1d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv1d, + error_inputs_func=error_inputs_conv1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing' + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv2d', + aliases=('conv2d',), + aten_name='conv2d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_conv2d), + error_inputs_func=error_inputs_conv2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv3d', + aliases=('conv3d',), + aten_name='conv3d', + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + error_inputs_func=error_inputs_conv3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + # TF32 + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), + torch.complex64: tol(atol=5e-3, rtol=1e-3)}), + 'TestCommon', 'test_noncontiguous_samples', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + 'TestCommon', 'test_variant_consistency_eager', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}), + 'TestMathBits', 'test_conj_view', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}), + 'TestOperators', 'test_vjpvmap', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + # AssertionError: Tensor-likes are not close! + # break slow tests + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.group_norm', + aten_name='group_norm', + aliases=('group_norm',), + ref=reference_group_norm, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_group_norm, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}), + "TestDecomp", + "test_comprehensive", + device_type="cpu" + ), + ], + sample_inputs_func=sample_inputs_group_norm, + reference_inputs_func=reference_inputs_group_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.instance_norm', + # no ref because instance_norm will often have numerical instability (large numbers or nan) + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=['running_mean', 'running_var'], + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_instance_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.layer_norm', + aten_name='layer_norm', + aten_backward_name='layer_norm_backward', + aliases=('layer_norm',), + ref=reference_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), + 'TestCommon', 'test_numpy_refs' + ), + DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'), + ], + sample_inputs_func=sample_inputs_layer_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.rms_norm', + aten_name='rms_norm', + aliases=('rms_norm',), + ref=reference_rms_norm, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rms_norm, + error_inputs_func=error_inputs_rms_norm,), + OpInfo('nn.functional.local_response_norm', + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_local_response_norm,), + OpInfo('constant_pad_nd', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=sample_inputs_constant_pad_nd, + supports_out=False, + skips=( + # bool can't be passed to Scalar arguments in JIT tracer because + # BoolType is not a subtype of ScalarType. + DecorateInfo( + unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('nn.functional.pad', + variant_test_name='constant', + aten_name='constant_pad_nd', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'), + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='reflect', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate_negative', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_nn_pad_replicate_negative, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Some negative padding cases cause a segfault on MPS + DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='circular', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Difference from is larger with decomposition new_empty_strided.default than original on output 0 + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'), + ), + supports_out=False), + OpInfo('nn.functional.hardswish', + aten_name="hardswish", + aten_backward_name='hardswish_backward', + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardswish, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_nonfusible_nodes=["aten::hardswish"]), + OpInfo('nn.functional.unfold', + aten_name='im2col', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_nn_unfold, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # NOTE: this failure may not reproduce consistently on different systems + # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185 + DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'), + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest-exact', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='linear', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'linear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bilinear', + supports_fwgrad_bwgrad=True, + supports_autograd=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bicubic', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='trilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='area', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'area'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.upsample_bilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'), + reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('_upsample_bilinear2d_aa', + op=torch.ops.aten._upsample_bilinear2d_aa, + aten_name='_upsample_bilinear2d_aa', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo( + "nn.functional.soft_margin_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + # doesn't support grad on target + sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False), + error_inputs_func=error_inputs_soft_margin_loss, + ), + OpInfo('nn.functional.upsample_nearest', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo( + "nn.functional.margin_ranking_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_margin_ranking_loss, + error_inputs_func=error_inputs_margin_ranking_loss, + reference_inputs_func=reference_inputs_margin_ranking_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo( + "nn.functional.multi_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multi_margin_loss, + reference_inputs_func=reference_inputs_multi_margin_loss, + error_inputs_func=error_inputs_multi_margin_loss, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + OpInfo( + "nn.functional.multilabel_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multilabel_margin_loss, + reference_inputs_func=reference_inputs_multilabel_margin_loss, + error_inputs_func=error_inputs_multilabel_margin_loss, + ), + OpInfo('nn.functional.leaky_relu', + aliases=None, + aten_name="leaky_relu", + aten_backward_name='leaky_relu_backward', + sample_inputs_func=sample_inputs_leaky_relu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + inplace_variant=lambda x, negative_slope=0.01: + torch.nn.functional.leaky_relu(x, negative_slope, inplace=True), + supports_autograd=True, + assert_autodiffed=True, + supports_gradgrad=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::leaky_relu"]), + OpInfo( + "nn.functional.multilabel_soft_margin_loss", + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_multilabel_soft_margin_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096 + # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32 + # leaked 4096 bytes CUDA memory on device 0 + DecorateInfo( + # Skip instead of expectedFailure because this fails + # locally for me but passes in CI. + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo('nn.functional.avg_pool2d', + aten_name='avg_pool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_avg_pool2d, + sample_inputs_func=sample_inputs_avgpool2d, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('nn.functional.fractional_max_pool2d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + sample_inputs_func=sample_inputs_fractional_max_pool2d, + decorators=( + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.fractional_max_pool3d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_fractional_max_pool3d, + decorators=( + # FIXME: both derivatives are implemented incorrectly + # https://github.com/pytorch/pytorch/issues/69322 + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.max_pool1d', + aten_name='max_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. + # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() + # to actually allocate memory + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + ), + error_inputs_func=error_inputs_max_pool1d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_pool2d', + aten_name='max_pool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + assert_jit_shape_analysis=True, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_max_pool2d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('max_pool2d_with_indices_backward', + op=max_pool2d_backward, + # We've defined a custom op, so there's no corresponding aten op + aten_name=None, + method_variant=None, + inplace_variant=None, + operator_variant=None, + inplace_operator_variant=None, + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_max_pool, + skips=( + # We've defined a custom op here, and we don't handle the case where we receive an out kwarg + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') + )), + OpInfo('nn.functional.max_pool3d', + aten_name='max_pool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + # TODO: investigate nondeterminism + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_max_pool3d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_unpool1d', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool1d', + variant_test_name='grad', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool2d', + aten_name='max_unpool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool2d', + variant_test_name='grad', + aten_name='max_unpool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_grad=False, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool3d', + aten_name='max_unpool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool3d', + variant_test_name='grad', + aten_name='max_unpool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.linear', + aten_name='linear', + supports_autograd=True, + supports_gradgrad=True, + sample_inputs_func=sample_inputs_linear, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # linear calls mm under the hood which is nondeterministic on CUDA + # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_expanded_weight=True, + decorators=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('nn.functional.bilinear', + aten_name='bilinear', + supports_autograd=True, + sample_inputs_func=sample_inputs_bilinear, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []), + decorators=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.glu', + aten_name='glu', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + sample_inputs_func=sample_inputs_glu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + UnaryUfuncInfo( + 'nn.functional.elu', + aten_backward_name='elu_backward', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.elu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + # Marked as a Unary function because it has some rather odd broadcasting semantics in its + # second argument + UnaryUfuncInfo( + 'nn.functional.prelu', + aten_backward_name='_prelu_kernel_backward', + ref=lambda x, weight: + np.maximum(0., x) + np.minimum(0., x) * + (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(x.ndim)])), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + # test_reference_numerics only tests the case when the weight tensor is a scalar + sample_kwargs=sample_kwargs_prelu_scalar_weight, + error_inputs_func=error_inputs_prelu, + sample_inputs_func=sample_inputs_prelu, + reference_inputs_func=reference_inputs_prelu, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + # https://github.com/pytorch/pytorch/issues/68752 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ], + ), + UnaryUfuncInfo( + 'nn.functional.celu', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.celu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + UnaryUfuncInfo( + 'nn.functional.rrelu', + aten_backward_name='rrelu_with_noise_backward', + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)), + sample_inputs_func=sample_inputs_rrelu, + error_inputs_func=error_inputs_rrelu, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # In-place operations do not play well with forward AD + # https://github.com/pytorch/pytorch/issues/77447 + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', + 'test_inplace_forward_mode_AD'), + # The noise vector that's generated in these tests is not the same elementwise + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + skip_correctness_check_compile_vs_eager=True, + ), + UnaryUfuncInfo( + 'nn.functional.selu', + ref=lambda x, inplace=False: + 1.0507009873554804934193349852946 * ( + np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1)) + ), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, # depends on 'elu' + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + OpInfo( + 'torch._scaled_mm_v2', + sample_inputs_func=sample_inputs_scaled_mm_v2, + dtypes=float8_types(), + dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), + supports_out=True, + supports_forward_ad=False, + supports_autograd=False, + decorators=[onlyCUDA, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], + skips=( + # Sample inputs isn't really parametrized on dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + ) + ), + OpInfo( + 'torch._scaled_mm', + sample_inputs_func=sample_inputs_scaled_mm, + dtypes=float8_types(), + dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), + supports_out=True, + supports_forward_ad=False, + supports_autograd=False, + decorators=[skipXPU, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], + skips=( + # Sample inputs isn't really parametrized on dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + ) + ), + OpInfo( + 'torch.ops.aten._safe_softmax.default', + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_safe_softmax, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_cow_input_no_materialize_backward=False, + decorators=[], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + ), + OpInfo( + 'nn.functional.scaled_dot_product_attention', + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), + sample_inputs_func=sample_inputs_scaled_dot_product_attention, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=False, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=[DecorateInfo(toleranceOverride( + {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ], + skips=( + # When attn mask is a composite tensor this fails backward by returning a none + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + # This is only failing on Linux Bionic 3.10 Cuda 11.6 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', + dtypes=(torch.float32,)), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Forward works for dtype=float64 which is the math path + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # Not implemented for Forward AD + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + device_type='cpu'), + # Not implemented for backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad', + device_type='cpu'), + # CPU and CUDA have inconsistencies for intermediate outputs + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cpu'), + # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', + device_type='cpu'), + # OpInfo was implemented with a lambda + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO Need to understand what this is testing and why it doesn't work + DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'), + # TODO skip this for now since we can't skip on runtime arch support + DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'), + # skip for sm < 80 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), + # FIXME + DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), + 'TestCompositeCompliance', 'test_cow_input', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),), + ), + OpInfo( + 'torch.ops.aten._flash_attention_forward', + sample_inputs_func=sample_inputs_flash_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16) + if not SM80OrLater + else custom_types(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], + skips=( + # Checking the scalar value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + OpInfo( + 'torch.ops.aten._efficient_attention_forward', + sample_inputs_func=sample_inputs_efficient_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16, torch.float32) + if not SM80OrLater + else custom_types(torch.float16, torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + # TODO: Skip because it produces a CUDA illegal memory access for some reason + skip_cow_input_backward=True, + # FIXME: mask_type == 2 (LowerRight) + decorators=[ + skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), + skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2"), + skipXPU], + skips=( + # Checking the scaler value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + UnaryUfuncInfo( + 'nn.functional.silu', + aten_backward_name='silu_backward', + ref=lambda x, inplace=False: x / (1 + np.exp(-x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_autograd=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,), device_type='cpu'), + ), + autodiff_nonfusible_nodes=["aten::silu"], + ), + # TODO: combine this with the nn.functional.silu OpInfo when + # complex autodiff for silu is supported or when + # the forward bug is fixed + # Note: silu errors when given inputs that require grad + # but it doesn't support grad in their dtype + # This is why the dtypes list above passes test_dtypes, + # because it's getting lucky and failing in forward + # because test_dtypes sets requires_grad to True + # THIS IS A BUG + UnaryUfuncInfo( + 'nn.functional.silu', + variant_test_name='complex', + ref=lambda x, inplace=False: + x / (1 + np.exp(-x)), + dtypes=complex_types(), + dtypesIfCUDA=complex_types(), + supports_forward_ad=False, + supports_autograd=False, + assert_autodiffed=False, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,)), + # FIXME: intentionally misreports dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.complex64, torch.cdouble)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.complex64,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.complex64,)))), + UnaryUfuncInfo( + 'nn.functional.hardsigmoid', + aten_backward_name='hardsigmoid_backward', + ref=reference_hardsigmoid, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=False, + supports_forward_ad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ], + skips=[ + # still want to test that first derivative works though second derivative isn't supported + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad")] + ), + UnaryUfuncInfo( + 'nn.functional.logsigmoid', + aten_name="log_sigmoid", + aten_backward_name='log_sigmoid_backward', + ref=reference_logsigmoid, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_autograd=True, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + # autodiff_nonfusible_nodes=["aten::log_sigmoid"], + decorators=[ + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + ], + skips=( + # Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'), + ), + ), + UnaryUfuncInfo( + 'nn.functional.mish', + aten_backward_name='mish_backward', + ref=lambda x: x * np.tanh(reference_softplus(x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.mish, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ], + ), + UnaryUfuncInfo( + 'nn.functional.softsign', + ref=lambda x: x / (np.abs(x) + 1), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int, torch.int8)),), + ), + UnaryUfuncInfo( + 'nn.functional.tanhshrink', + ref=lambda x: x - np.tanh(x), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05), + torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + # tan(j * pi/2 * odd_number) is nan which also make tanhshrink nan. + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0) + ), + UnaryUfuncInfo( + 'nn.functional.threshold', + ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype), + dtypes=all_types_and(torch.half, torch.bfloat16), + inplace_variant=lambda x, threshold, value: + torch.nn.functional.threshold(x, threshold, value, inplace=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}, + {'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}), + # TODO(whc) should not need sample_inputs_func, but without it + # kwargs aren't being hooked up properly + sample_inputs_func=sample_inputs_threshold, + ), + OpInfo( + "nn.functional.triplet_margin_loss", + sample_inputs_func=sample_inputs_triplet_margin_loss, + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "nn.functional.triplet_margin_with_distance_loss", + sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True), + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # This test cannot handle a callable passed to `distance_function`. If we would use + # `distance_function=None`, the test would pass fine. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + BinaryUfuncInfo('nextafter', + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + supports_rhs_python_scalar=False), + OpInfo( + "to", + op=lambda x, *args, **kwargs: x.to(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_to, + skips=( + # RuntimeError: undefined value cpu + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + # NotImplementedError: Cannot copy out of meta tensor; no data! + DecorateInfo( + unittest.skip("Skipped!"), + "TestMeta", + "test_meta_outplace", + ), + # https://github.com/pytorch/pytorch/issues/84335 + DecorateInfo( + unittest.skip("Skipped!"), + "TestProxyTensorOpInfo", + "test_make_fx_symbolic_exhaustive", + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + OpInfo('topk', + dtypes=all_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_topk), + # Multiple variants for batch_norm to test with and without cuDNN disabled + # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details + OpInfo('nn.functional.batch_norm', + aten_name='batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + sample_inputs_func=sample_inputs_batch_norm, + skips=( + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.bfloat16, torch.float16)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}), + 'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"), + )), + # This variant tests batch_norm with cuDNN disabled only on CUDA devices + OpInfo('nn.functional.batch_norm', + variant_test_name='without_cudnn', + aten_name='batch_norm', + dtypes=empty_types(), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + decorators=[onlyCUDA, disablecuDNN], + skips=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}), + 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_batch_norm), + OpInfo( + "nn.functional.binary_cross_entropy", + aten_backward_name='binary_cross_entropy_backward', + sample_inputs_func=sample_inputs_binary_cross_entropy, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + gradcheck_fast_mode=False, + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestCudaFuserOpInfo", + ), + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestNNCOpInfo", + "test_nnc_correctness", + ), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + ), + # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5] + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + ), + skips=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the + # standard entry, second is to run gradcheck tests on the second argument. + BinaryUfuncInfo('igamma', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammainc',), + dtypesIfCUDA=floating_types(), + # TODO: FIXME + supports_rhs_python_scalar=False, + supports_autograd=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implemented grad for both inputs + # BinaryUfuncInfo('igamma', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments. + # op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # skips=( + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"),"), + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + BinaryUfuncInfo('igammac', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammaincc',), + dtypesIfCUDA=floating_types(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implementing grad for both inputs + # BinaryUfuncInfo('igammac', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments + # op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # decorators=[ + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"), + # ], + # skips=( + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + UnaryUfuncInfo('nn.functional.softshrink', + aten_name="softshrink", + aten_backward_name='softshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_softshrink, + error_inputs_func=error_inputs_softshrink), + UnaryUfuncInfo('nn.functional.hardshrink', + aten_name="hardshrink", + aten_backward_name='hardshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardshrink, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardshrink"]), + UnaryUfuncInfo('nn.functional.hardtanh', + aten_name="hardtanh", + aten_backward_name='hardtanh_backward', + dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16), + backward_dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardtanh, + error_inputs_func=error_inputs_hardtanh, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardtanh"]), + OpInfo('nn.functional.gelu', + aten_name="gelu", + aten_backward_name='gelu_backward', + ref=reference_gelu if TEST_SCIPY else None, + error_inputs_func=error_inputs_gelu, + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_gelu, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::gelu"], + skips=( + # AssertionError: Tensor-likes are not close! + # May not replicate in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('nn.functional.relu6', + aten_name="relu6", + dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypes=floating_types_and(torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::relu6"]), + OpInfo('mm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + )), + OpInfo('mode', + op=torch.mode, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Resized a non-empty tensor but did not warn about it + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FIXME: + # Expected 2114 but got 1123. + # Absolute difference: 991 (up to 0.001 allowed) + # Relative difference: 0.46877956480605487 (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_compare_cpu", + dtypes=(torch.float32,), + device_type="cuda", + ), + ), + sample_inputs_func=sample_inputs_mode,), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', + domain=(1, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', + domain=(2, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', + domain=(3, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), + BinaryUfuncInfo('ne', + ref=np.not_equal, + aliases=('not_equal',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('narrow', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), + skips=( + # Use of .item() + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('narrow_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + supports_autograd=False, + # https://github.com/pytorch/pytorch/issues/86931 + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), + skips=( + # https://github.com/pytorch/pytorch/issues/84577 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('view_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + ref=lambda x, newshape: np.reshape(x, newshape).copy(), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + sample_inputs_func=sample_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo( + unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides" + ), + )), + UnaryUfuncInfo('neg', + aliases=('negative', ), + ref=np.negative, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + error_inputs_func=error_inputs_neg, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('dist', + op=torch.dist, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_dist), + OpInfo('outer', + op=torch.outer, + aliases=('ger', ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_outer,), + OpInfo('ormqr', + op=torch.ormqr, + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_ormqr, + error_inputs_func=error_inputs_ormqr, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('permute', + ref=np.transpose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=True, + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute), + OpInfo('permute_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=False, # torch.permute is also not varargs + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + BinaryUfuncInfo('pow', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + ref=np.power, + # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled + # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently + # unsupported on CPU. + backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + # FIXME Complex values error with: Greatest absolute difference: nan at index + # Ref: https://github.com/pytorch/pytorch/issues/76853 + # For `chalf`, reference computation in `numpy` is computed in `cfloat`. + # Output of `chalf` saturates to `inf` quicker than reference due to its small range + # which leads to failure of this test. + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + # FIXME: + # Mismatched elements: 1 / 500 (0.2%) + # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + )), + BinaryUfuncInfo('float_power', + ref=np.float_power, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # FIXME + # AssertionError: Object comparison failed: torch.float64 != torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # -3.43399e+38 is outside the range of representable values of type 'float' + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + # Inplace always promotes to double and thus other floating dtypes are not supported + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bfloat16, torch.float16, torch.float32]), + )), + OpInfo('qr', + op=torch.qr, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # In-place ops + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]), + UnaryUfuncInfo('rad2deg', + ref=np.degrees, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + UnaryUfuncInfo('real', + ref=np.real, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo( + "roll", + ref=np.roll, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + error_inputs_func=error_inputs_roll, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_roll, + decorators=(onlyNativeDeviceTypes,), + ), + OpInfo( + "rot90", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + error_inputs_func=error_inputs_rot90, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rot90, + ), + # To test reference numerics against multiple values of argument `decimals`, + # we make multiple OpInfo entries with each entry corresponding to different value of decimals. + UnaryUfuncInfo('round', + ref=np.round, + aliases=('special.round',), + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + ), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_0', + aliases=('special.round',), + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_neg_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('sin', + ref=np.sin, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + decorators=(precisionOverride({torch.bfloat16: 1e-2}),)), + UnaryUfuncInfo('sinc', + ref=np_sinc_with_fp16_as_fp32, + aliases=('special.sinc',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('sinh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('sign', + ref=reference_sign, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + )), + UnaryUfuncInfo('sgn', + ref=reference_sgn, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True), + OpInfo('split', + # Cannot declare this aten_name because of + # test_variant_consistency_jit_split_list_args_cpu_float32 + decomp_aten_name='split_with_sizes', + variant_test_name='list_args', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=partial(sample_inputs_split, list_args=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + # `unsafe_split` supports only `int` for split_size argument + OpInfo('unsafe_split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True, + check_batched_forward_grad=False), + OpInfo('split_with_sizes', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo('split_with_sizes_copy', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # No error raised + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"), + )), + BinaryUfuncInfo('__radd__', + op=torch.Tensor.__radd__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::add'],), + BinaryUfuncInfo('__rdiv__', + op=torch.Tensor.__rdiv__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + promotes_int_to_float=True, + lhs_make_tensor_kwargs={'exclude_zero': True}, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + skips=( + # https://github.com/pytorch/pytorch/issues/76806 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],), + BinaryUfuncInfo('__rmul__', + op=torch.Tensor.__rmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::mul'],), + BinaryUfuncInfo('__rand__', + op=torch.Tensor.__rand__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__ror__', + op=torch.Tensor.__ror__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__rxor__', + op=torch.Tensor.__rxor__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('__rmatmul__', + op=torch.Tensor.__rmatmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + assert_autodiffed=True, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}), + "TestDecomp", "test_comprehensive", device_type="cuda", + active_if=TEST_WITH_ROCM), + ), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # Fails on XLA. + # AssertionError: False is not true : Tensors failed to compare as equal + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + BinaryUfuncInfo('__rmod__', + op=torch.Tensor.__rmod__, + dtypes=floating_types_and(torch.bfloat16, torch.half,), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + # Support autograd after torch.remainder(Tensor, Tensor) supports + # autograd of the second argument. + # https://github.com/pytorch/pytorch/pull/58476/files#r637167630 + # supports_autograd=False, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::remainder'],), + BinaryUfuncInfo('__rpow__', + op=torch.Tensor.__rpow__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + # Reference: https://github.com/pytorch/pytorch/issues/54774 + # "log2" "_vml_cpu" not implemented for Half + backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # TODO: FIXME tolerance is too high + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'), + DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::pow'],), + BinaryUfuncInfo('__rsub__', + op=torch.Tensor.__rsub__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::rsub'],), + BinaryUfuncInfo('rsub', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_inplace_autograd=False, + assert_autodiffed=None, + sample_inputs_func=sample_inputs_add_sub), + OpInfo('select', + aten_backward_name='select_backward', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_select, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('select_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_select_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('slice', + op=torch.ops.aten.slice.Tensor, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_slice, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_scripting=False, + supports_inplace_autograd=False, + supports_out=False), + OpInfo('slice_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + UnaryUfuncInfo('signbit', + ref=np.signbit, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False,), + UnaryUfuncInfo('tan', + ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + # FIXME: + # Mismatched elements: 2 / 400 (0.5%) + # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestInductorOpInfo", + "test_comprehensive", + dtypes=(torch.float16,), + device_type="cuda", + ), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + # tan(pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)), + UnaryUfuncInfo('tanh', + ref=np.tanh, + aten_backward_name='tanh_backward', + aliases=('nn.functional.tanh',), + decorators=(precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + # tan(j * pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + OpInfo('tensor_split', + ref=np.array_split, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + ), + sample_inputs_func=sample_inputs_tensor_split,), + OpInfo('hsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_hsplit, + error_inputs_func=error_inputs_hsplit,), + OpInfo('vsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_vsplit, + error_inputs_func=error_inputs_vsplit,), + OpInfo('dsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_dsplit, + error_inputs_func=error_inputs_dsplit,), + OpInfo('triangular_solve', + op=torch.triangular_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_legacy_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}), + 'TestConsistency', 'test_output_match', device_type='cpu', + ), + ], + skips=( + # AssertionError: Scalars are not equal! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # Gradcheck fails + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=floating_and_complex_types()), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + UnaryUfuncInfo('trunc', + aliases=('fix', ), + ref=np.trunc, + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + UnaryUfuncInfo('exp2', + aliases=('special.exp2', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('expm1', + aliases=('special.expm1', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + assert_autodiffed=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('nan_to_num', + ref=np.nan_to_num, + dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + ), + # Passing numpy_kwargs via sample_kwargs, as numpy does comparison + # with BFloat16 in float, since it currently doesn't support BFloat16. + # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556 + sample_kwargs=lambda device, dtype, input: ({}, + {'posinf': torch.finfo(torch.bfloat16).max, + 'neginf': torch.finfo(torch.bfloat16).min}) + if dtype is torch.bfloat16 else ({}, {})), + UnaryUfuncInfo('reciprocal', + ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + )), + UnaryUfuncInfo('rsqrt', + ref=lambda x: np.reciprocal(np.sqrt(x)), + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.half: 5e-2}),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + )), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + supports_sparse=True, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + UnaryUfuncInfo('square', + ref=np.square, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + # >>> t = torch.tensor(complex(-0.01, float("inf"))) + # >>> np.square(t.numpy()) + # (-inf-infj) + # >>> t.square() + # tensor(-inf-infj) + # >>> t.cuda().square() + # tensor(inf+nanj, device='cuda:0') + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=[torch.bool]), + ),), + OpInfo('lerp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_lerp, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('angle', + ref=np.angle, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_complex_to_float=True, + skips=( + # Ref: https://github.com/pytorch/pytorch/issues/78413 + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),), + )), + UnaryUfuncInfo('isfinite', + ref=np.isfinite, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isinf', + ref=np.isinf, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isposinf', + ref=np.isposinf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isneginf', + ref=np.isneginf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isreal', + ref=np.isreal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isnan', + ref=np.isnan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + OpInfo('einsum', + # we need this lambda because SampleInput expects tensor input as the first argument + # TODO(@heitorschueroff) update SampleInput to handle such cases + op=lambda tensors, equation: torch.einsum(equation, tensors), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + sample_inputs_func=sample_inputs_einsum, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # test does not work with passing lambda for op + # there's a test `test_einsum` in `test_jit.py` to handle this case + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('svd', + op=torch.svd, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_svd, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # We're using at::allclose, which does not have a batching rule + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('svd_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + # Due to the use of randomness + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_svd_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=1e-02, rtol=1e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), + )), + OpInfo('pca_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_pca_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=4e-02, rtol=4e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}), + 'TestOperators', 'test_grad'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('polar', + dtypes=floating_types(), + # this function is undefined if 'abs' values are <0 + supports_forward_ad=True, + lhs_make_tensor_kwargs=dict(low=0), + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0 + # Numerical: + # tensor([[0.]], dtype=torch.float64) + # Analytical: + # tensor([[-0.0047]], dtype=torch.float64, grad_fn=) + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + )), + # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries. + # To test reference numerics against multiple values of argument `n`, + # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4). + # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing. + UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name='polygamma_n_0', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)), + *(UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name=f'polygamma_n_{n_}', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + decorators=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1), + torch.float32: tol(atol=1e-4, rtol=1e-2)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal', + active_if=IS_WINDOWS), + ), + skips=( + # Redundant tests + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + # Mismatch: https://github.com/pytorch/pytorch/issues/55357 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)) + for n_ in (1, 2, 3, 4)), + OpInfo('ravel', + ref=np.ravel, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_ravel, + ), + OpInfo('unravel_index', + ref=np.unravel_index, + dtypes=integral_types_and(), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_unravel_index, + ), + OpInfo('reshape', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo('reshape_as', + op=lambda x, other: x.reshape_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('view', + op=lambda x, shape: x.view(shape), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('view_as', + op=lambda x, other: x.view_as(other), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides") + )), + OpInfo('atleast_1d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_atleast1d2d3d, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + ), + OpInfo('atleast_2d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('atleast_3d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('flatten', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_flatten, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_flatten, + reference_inputs_func=reference_inputs_flatten, + ), + OpInfo('unflatten', + op=torch.unflatten, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_unflatten, + ), + OpInfo('column_stack', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_column_stack,), + OpInfo('pinverse', + op=torch.pinverse, + dtypes=floating_and_complex_types(), + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + sample_inputs_func=sample_inputs_linalg_invertible, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('gather', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_gather, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_gather, + ), + OpInfo('index_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + inplace_variant=torch.Tensor.index_fill_, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'), + ), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True)), + OpInfo('index_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_select', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_select, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_add', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + inplace_variant=torch.Tensor.index_add_, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_add, + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + *(OpInfo('index_reduce', + variant_test_name=reduction_type, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive'), + ), + supports_out=True, + sample_inputs_func=sample_inputs_index_reduce, + ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), + OpInfo('_unsafe_masked_index', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs__unsafe_masked_index, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('_unsafe_masked_index_put_accumulate', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu' + ), + ), + sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('__getitem__', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + supports_scripting=False, + op=torch.Tensor.__getitem__, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),), + sample_inputs_func=sample_inputs_getitem), + OpInfo('index_put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + test_neg_view=False, + sample_inputs_func=sample_inputs_index_put, + skips=( + DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], + device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), + )), + OpInfo('sort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM), + )), + OpInfo('unique', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), + sample_inputs_func=sample_inputs_unique, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('unique_consecutive', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_unique_consecutive, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_gradgrad=False, # vmap complains of the sizes + sample_inputs_func=sample_inputs_put), + OpInfo('take', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_grad=False, # vmap complains of the sizes + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_take, + error_inputs_func=error_inputs_take), + OpInfo('scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter, + error_inputs_func=error_inputs_scatter_and_scatter_add, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + UnaryUfuncInfo( + 'bfloat16', + op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'bool', + op=lambda x, *args, **kwargs: x.bool(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attributis not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'byte', + op=lambda x, *args, **kwargs: x.byte(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_byte, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'char', + op=lambda x, *args, **kwargs: x.char(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'double', + op=lambda x, *args, **kwargs: x.double(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'float', + op=lambda x, *args, **kwargs: x.float(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'half', + op=lambda x, *args, **kwargs: x.half(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=True, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'int', + op=lambda x, *args, **kwargs: x.int(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'long', + op=lambda x, *args, **kwargs: x.long(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'short', + op=lambda x, *args, **kwargs: x.short(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'cdouble', + op=torch.Tensor.cdouble, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'cfloat', + op=torch.Tensor.cfloat, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'chalf', + op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + # use of lambda doesn't work with test_normalize_operator_exhaustive + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', + device_type='cpu'), + # TypeError: 'int' object is not iterable + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + OpInfo('empty_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + reference_inputs_func=reference_inputs_like_fns, + supports_autograd=False, + skips=( + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), + "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('zeros_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + )), + OpInfo('ones_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + )), + OpInfo('randn', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs), + supports_out=True, + sample_inputs_func=sample_inputs_randn, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randn generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + # randn fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('randn_like', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('rand_like', + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('randint', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.randint, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randint generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # randint fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices', + dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM), + )), + OpInfo('randint_like', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randint_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint_like, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('full_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, + torch.uint16, torch.uint32), + supports_out=False, + sample_inputs_func=sample_inputs_full_like, + supports_autograd=False, + ), + OpInfo('new_zeros', + op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('new_ones', + op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('ones', + op=torch.ones, + supports_autograd=False, + supports_varargs=True, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('zeros', + op=torch.zeros, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('full', + op=torch.full, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_full, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # RuntimeError: UNSUPPORTED DTYPE: bool + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('new_empty', + op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + supports_autograd=False), + OpInfo('new_empty_strided', + op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True), + supports_autograd=False, + skips=( + # FX failed to normalize op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Lazy tensor failures + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('empty_strided', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_empty_strided, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'), + # Lazy tensor failures + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'), + # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single + # memory location. Please clone() the tensor before performing the operation. + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('empty', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('eye', + dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_eye, + error_inputs_func=error_inputs_eye, + supports_out=True, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO: same as this? + # https://github.com/pytorch/pytorch/issues/81774 + # also see: arange, new_full + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + )), + OpInfo('empty_permuted', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty_permuted, + error_inputs_func=error_inputs_empty_permuted, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('scalar_tensor', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_scalar_tensor, + supports_autograd=False, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('new_full', + op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_full, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('multinomial', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.multinomial, inp, *args, **kwargs), + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_multinomial, + error_inputs_func=error_inputs_multinomial, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Strides are not the same! + # This may not be reproducible in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_autograd=False), + OpInfo('normal', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.normal, inp, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_first, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # RuntimeError: Difference from {dtype} is larger with decomposition + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # The inplace variant (Tensor.normal_) is different from torch.normal + # inplace variant Tensor.normal_ is decomposed using randn_like() + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))), + OpInfo('normal', + # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here + variant_test_name='number_mean', + op=lambda std, mean, *args, **kwargs: + wrapper_set_seed(torch.normal, mean, std, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_second, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # AssertionError in CUDA variant + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))), + OpInfo('bernoulli', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs), + # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli + inplace_variant=None, + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_bernoulli, + error_inputs_func=error_inputs_bernoulli, + skips=( + # vmap: We do not yet support calling random operations inside of vmap + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Expected RuntimeError when doing an unsafe cast from a result of + # dtype torch.float32 into an out= with dtype torch.lon + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), + OpInfo('scatter_add', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + inplace_variant=torch.Tensor.scatter_add_, + sample_inputs_func=sample_inputs_scatter_add, + error_inputs_func=error_inputs_scatter_and_scatter_add, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('stack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_stack, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # https://github.com/pytorch/pytorch/issues/77046 + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('_chunk_cat', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_chunk_cat, + error_inputs_func=error_inputs_chunk_cat, + supports_autograd=False, + supports_out=True, + ), + OpInfo('hstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + BinaryUfuncInfo('hypot', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False), + OpInfo('histogram', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogram is only implemented on CPU + sample_inputs_func=sample_inputs_histogram, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Not Implemented on XLA. + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'), + )), + OpInfo('histogramdd', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU + sample_inputs_func=sample_inputs_histogramdd, + error_inputs_func=error_inputs_histogramdd, + supports_autograd=False, + skips=( + # Not implemented on CUDA + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('histc', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), + sample_inputs_func=sample_inputs_histc, + supports_out=True, + supports_autograd=False, + skips=( + # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor + # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast + # from a result of dtype torch.float32 into an out= with dtype torch.long" + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('bincount', + dtypes=integral_types_and(), + sample_inputs_func=sample_inputs_bincount, + supports_out=False, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('bucketize', + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_bucketize, + reference_inputs_func=reference_inputs_bucketize, + error_inputs_func=error_inputs_bucketize, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('searchsorted', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_searchsorted, + supports_autograd=False, + ref=reference_searchsorted, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('cat', + ref=_cat_np, + aliases=('concat', 'concatenate'), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + sample_inputs_func=sample_inputs_cat_concat, + reference_inputs_func=reference_inputs_cat, + error_inputs_func=error_inputs_cat, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + assert_autodiffed=True, + skips=( + # https://github.com/pytorch/pytorch/issues/89353 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: Arguments for call not valid. + # Expected a value of type 'List[Tensor]' for argument + # 'tensors' but instead found type 'Tensor (inferred)'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # see https://github.com/pytorch/pytorch/issues/99806 + # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + )), + OpInfo('unbind', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=False, + ), + OpInfo('unbind_copy', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=True, + check_batched_grad=False, + ), + OpInfo('vstack', + aliases=('row_stack',), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: _fn() Expected a value of type + # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)), + OpInfo('dstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo('unfold', + op=lambda x, *args: x.unfold(*args), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ), + sample_inputs_func=sample_inputs_unfold), + OpInfo('unfold_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_unfold), + OpInfo('msort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_msort), + OpInfo('movedim', + aliases=('moveaxis',), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_movedim_moveaxis, + reference_inputs_func=reference_movedim_moveaxis, + error_inputs_func=error_movedim_moveaxis), + OpInfo('renorm', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_renorm, + error_inputs_func=error_inputs_renorm, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: Difference from float64 is larger with decomposition + # linalg_vector_norm.default than original on output 0. + # Original max diff: 2.560596747969157e-07, + # Decomp max diff: 1.8187482915266173e-06 + DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive', + device_type='cpu', dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + ShapeFuncInfo('repeat', + op=lambda x, dims: x.repeat(dims), + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('squeeze', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze), + OpInfo('squeeze', + ref=_squeeze_ref, + variant_test_name="multiple", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze_multiple), + OpInfo('squeeze_copy', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze, + skips=( + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + UnaryUfuncInfo( + 'fill', + ref=_fill_np, + method_variant=None, + sample_kwargs=_fill_sample_kwargs, + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + skips=( + # JIT has issue when op is passed as lambda + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'), + )), + OpInfo('resize_', + op=lambda x, shape: x.clone().resize_(shape), + method_variant=None, + inplace_variant=torch.Tensor.resize_, + # the test fails because resize_ doesn't work with imag views as expected by the test + # https://github.com/pytorch/pytorch/issues/65945 + test_neg_view=False, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('resize_as_', + op=lambda x, other: torch.resize_as_(x.clone(), other), + method_variant=None, + inplace_variant=torch.Tensor.resize_as_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('take_along_dim', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_take_along_dim, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + ShapeFuncInfo('tile', + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile), + OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid' + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('trapezoid', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('cumulative_trapezoid', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + sample_inputs_func=sample_cumulative_trapezoid,), + OpInfo('unsqueeze', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze), + OpInfo('unsqueeze_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + BinaryUfuncInfo('xlogy', + aliases=('special.xlogy',), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_int_to_float=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # We don't test 0 as the gradient will be NaN and it'll break + rhs_make_tensor_kwargs=dict(low=0.01)), + OpInfo('zero_', + op=lambda x: torch.zero_(x.clone()), + method_variant=None, + inplace_variant=torch.Tensor.zero_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_zero_), + OpInfo('logsumexp', + aliases=('special.logsumexp',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_fast_mode=False, + sample_inputs_func=sample_inputs_logsumexp, + reference_inputs_func=reference_inputs_logsumexp), + OpInfo('trace', + dtypes=all_types_and_complex(), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + error_inputs_func=error_inputs_trace, + supports_inplace_autograd=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_trace), + OpInfo('transpose', + ref=_numpy_ref_transpose, + aliases=('swapdims', 'swapaxes'), + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims), + OpInfo('transpose_copy', + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + )), + OpInfo('T', + op=lambda x: x.T, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T, + error_inputs_func=error_inputs_T), + OpInfo('H', + op=lambda x: x.H, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T), + OpInfo('mT', + op=lambda x: x.mT, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('mH', + op=lambda x: x.mH, + aliases=('adjoint',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('tril', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('tril_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('kron', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kron, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('inner', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_inner, + ), + OpInfo('tensordot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_tensordot, + skips=( + # Skip operator schema test because this is a functional and not an operator. + # Reference: https://github.com/pytorch/pytorch/issues/54574 + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ) + ), + OpInfo('to_sparse', + op=lambda x, *args: x.to_sparse(*args), + sample_inputs_func=sample_inputs_to_sparse, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + backward_dtypes=floating_types(), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_sparse_csr=True, + supports_sparse_csc=True, + check_batched_grad=False, + check_batched_gradgrad=False, + skips=( + # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend + DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'), + # TODO: FIXME: complex inputs requiring grad error in forward + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Allowed exception: sparse tensors don't have strides + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'), + # TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1. + DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"), + 'TestSparseCSR', 'test_sparse_csr_consistency'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ) + ), + OpInfo('logcumsumexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'), + # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble' + # Falling back to non-numerically stabilized exp, causing nan in the results. + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=7e-5, rtol=6e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + sample_inputs_func=sample_inputs_logcumsumexp, + error_inputs_func=error_inputs_logcumsumexp), + UnaryUfuncInfo('sigmoid', + aliases=('special.expit', 'nn.functional.sigmoid'), + aten_backward_name='sigmoid_backward', + ref=reference_sigmoid if TEST_SCIPY else None, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + assert_autodiffed=True, + # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 1j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + UnaryUfuncInfo('digamma', + ref=scipy.special.digamma if TEST_SCIPY else None, + aliases=('special.psi', 'special.digamma',), + decorators=(precisionOverride({torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erf', + ref=scipy.special.erf if TEST_SCIPY else None, + aliases=('special.erf', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + + ), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfc', + ref=scipy.special.erfc if TEST_SCIPY else None, + aliases=('special.erfc', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfinv', + ref=scipy.special.erfinv if TEST_SCIPY else None, + aliases=('special.erfinv', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + domain=(-1, 1), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.expectedFailure, 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo("nn.functional.smooth_l1_loss", + ref=reference_smooth_l1_loss, + sample_inputs_func=sample_inputs_smooth_l1_loss, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + backward_dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)), + OpInfo( + "nn.functional.l1_loss", + ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)), + sample_inputs_func=sample_inputs_l1_loss, + error_inputs_func=error_inputs_l1_loss, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + ), + ), + UnaryUfuncInfo('lgamma', + ref=reference_lgamma if TEST_SCIPY else None, + aliases=('special.gammaln', ), + decorators=(precisionOverride({torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + # lgamma have multiple singularities at x <= 0 + reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), + OpInfo( + 'logdet', + dtypes=floating_and_complex_types(), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + # `log_softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + OpInfo( + 'log_softmax', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + aten_backward_name='_log_softmax_backward_data', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo( + 'log_softmax', + variant_test_name='with_dtype', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('logit', + aten_backward_name='logit_backward', + ref=scipy.special.logit if TEST_SCIPY else None, + domain=(0, 1), + aliases=('special.logit', ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_logit), + OpInfo('where', + # Currently only the `input` is tested in gradcheck. + # If we pass `condition` first, none of the input which supports + # autograd will be tested. Hence the following lambda. + op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs), + ref=lambda self, condition, other: np.where(condition, self, other), + sample_inputs_func=sample_inputs_where, + reference_inputs_func=reference_inputs_where, + error_inputs_func=error_inputs_where, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + ), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)), + OpInfo('nonzero', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # nonzero(): argument 'out' must be Tensor, not tuple + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67458 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # nonzero is not raising a warning when the out is resized + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Can't find schemas for this operator for some reason + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nonzero_static', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero_static, + supports_out=False, + supports_autograd=False, + decorators=[onlyCPU], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + # Following tests are for jiterator's python interface + # Jiterator can be used to author elementwise CUDA kernel + # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op + # See create_jit_fn in jiterator.py for more information + UnaryUfuncInfo( + 'jiterator_unary', + op=torch.cuda.jiterator._create_jit_fn("template T unary(T x) { return x * x + x; }"), + ref=lambda x: x * x + x, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[ + onlyCUDA, + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_hard'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + ], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=[torch.bool]), + # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], active_if=TEST_WITH_ROCM), + # Newer numpy generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], device_type='cuda'), + # Expected failure: torch.jiterator_unary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1), + ref=lambda input, other, *, alpha=1: ( + np.add(input, other) + if alpha == 1 + else np.add(input, np.multiply(alpha, other)) + ), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_binary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_4inputs_with_extra_args', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }", + alpha=1, beta=1), + ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary_return_by_ref', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_return_by_ref(T i0, T i1, T& out0) { + out0 = i0 + i1; + } + """, + num_outputs=1), + ref=operator.add, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_2inputs_2outputs', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_2outputs(T i0, T i1, T& out0, T& out1) { + out0 = i0 + i1; + out1 = i0 - i1; + } + """, + num_outputs=2), + ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + # `torch.norm` has multiple code paths depending on the value of `p`. + # These paths have different dtype support. Also JIT supports, + # most variants but not all of them. So we split the OpInfo entries, + # for `norm` based on the code-paths and JIT support. + OpInfo( + "norm", + sample_inputs_func=sample_inputs_norm, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # TODO Benchmark again with the new implementation + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='nuc', + sample_inputs_func=sample_inputs_norm_nuc, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + check_batched_gradgrad=False, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_and_complex_types(), + dtypesIfCUDA=floating_and_complex_types(), + skips=( + # Dispatches in Python to matrix_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='fro', + sample_inputs_func=sample_inputs_norm_fro, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + skips=( + # MPS has some mild accuracy issues for float16. We divide the tolerances by 10 + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}), + 'TestConsistency', + 'test_output_match', + + ), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo( + "norm", + variant_test_name="inf", + sample_inputs_func=sample_inputs_norm_inf, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + # fast gradcheck produces NaNs + gradcheck_fast_mode=False, + skips=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)) + ), + ), + OpInfo('t', + sample_inputs_func=sample_inputs_t, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo('t_copy', + sample_inputs_func=sample_inputs_t, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo( + "nn.functional.dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Probably because we have used lambda for the op here + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # inplace variant dispatches to dropout kernel, while on CUDA + # the op dispatches to _fused_dropout (with a few more conditions) + # hence, different values and this skip here + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "native_dropout_backward", + op=torch.ops.aten.native_dropout_backward.default, + aten_name="native_dropout_backward", + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_dropout_backward, + skips=( + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # Lazy tensor failures + DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # These tests fail only when built with ASAN + DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN), + DecorateInfo( + unittest.skip("Fails with ASAN"), + 'TestLazyOpInfo', + 'test_correctness_with_reusing_ir', + active_if=TEST_WITH_ASAN + ), + ), + ), + OpInfo( + "nn.functional.dropout2d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (3, 4) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.dropout3d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + check_batched_forward_grad=False, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: Tensor-likes are not close! + # Fails in cuda11.7 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='xpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), + # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype + # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="with_train", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # vmap: We do not yet support calling random operations inside of vmap. + # Please perform random operations outside of vmap as a workaround + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="without_train", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=partial(sample_inputs_dropout, train=False), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.one_hot", + ref=reference_one_hot, + supports_out=False, + dtypes=_dispatch_dtypes((torch.int64,)), + sample_inputs_func=sample_inputs_one_hot, + ), + OpInfo( + "nn.functional.embedding", + aten_backward_name="embedding_dense_backward", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_embedding, + allow_cow_input_materialize_forward=[0], + error_inputs_func=error_inputs_embedding, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Fails on CI https://github.com/pytorch/pytorch/issues/85377 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + # Reference: https://github.com/pytorch/pytorch/issues/67084 + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + # Not a problem: embedding does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + # Fails due to non-determinism (see issue #74679) + # TODO: Investigate why more granular skips in the test don't work in CI + DecorateInfo(unittest.skip('Skipped!'), + 'TestExpandedWeightFunctional', + 'test_expanded_weight_forward'), + ), + supports_expanded_weight=True, + supports_out=False, + ), + OpInfo( + "nn.functional.embedding_bag", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + # backward is not supported for mode `max` and dtype `bfloat16` + backward_dtypesIfCUDA=floating_types_and(torch.float16), + sample_inputs_func=sample_inputs_embedding_bag, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Not a problem: embedding_bag does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + supports_gradgrad=False, + allow_cow_input_materialize_forward=[0], + ), + OpInfo( + "nn.functional.multi_head_attention_forward", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_multi_head_attention_forward, + skips=( + # Tensor-likes are not close + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'), + + # TODO skip this for now since we can't skip on runtime arch support (taken from scaled_dot_product_attention) + DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'), + # randomness + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # lambda impl + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # tests running very slowly break slow tests, so we skip them instead of using `slowTest`. + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_comprehensive', + dtypes=(torch.bfloat16, torch.float16), + ), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_quick', + dtypes=(torch.bfloat16, torch.float16))), + supports_out=False, + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + ), + UnaryUfuncInfo( + "nn.functional.softplus", + aten_backward_name='softplus_backward', + ref=reference_softplus, + sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + decorators=( + DecorateInfo( + toleranceOverride + ({ + torch.half: tol(atol=1e-2, rtol=1e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1e-2), + }), + 'TestUnaryUfuncs'), + ), + ), + OpInfo( + "nn.functional.mse_loss", + aten_backward_name='mse_loss_backward', + ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2), + sample_inputs_func=sample_inputs_loss, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ), + ), + OpInfo( + "nn.functional.grid_sample", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sample, + reference_inputs_func=reference_inputs_grid_sample, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15), + # TODO: delete this OpInfo once we add meta support for grid_sampler_3d + OpInfo( + "grid_sampler_2d", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_2d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), + active_if=IS_WINDOWS), + ),), + # TODO: Remove grid_sampler_3d tests once `nn.functional.grid_sample` has + # MPS support for all cases. + OpInfo( + "grid_sampler_3d", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_3d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15, + skips=( + # NOTE: Only run on MPS + DecorateInfo(unittest.skip('Skipped!'), device_type='cpu'), + DecorateInfo(unittest.skip('Skipped!'), device_type='cuda'), + DecorateInfo(unittest.skip('Skipped!'), device_type='xpu'), + DecorateInfo(unittest.skip('Skipped!'), device_type='meta'), + ),), + OpInfo( + "argwhere", + ref=np.argwhere, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_argwhere, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + ReductionOpInfo( + 'all', + identity=True, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.all), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'any', + identity=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.any), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'amax', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amax), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'amin', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amin), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'argmax', + supports_multiple_dims=False, + supports_autograd=False, + assert_jit_shape_analysis=True, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmax, supports_keepdims=False), + ), + ReductionOpInfo( + 'argmin', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmin, supports_keepdims=False), + ), + ReductionOpInfo( + 'count_nonzero', + identity=0, + supports_out=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_reduction_count_nonzero, + ref=reference_reduction_numpy(np.count_nonzero), + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionOpInfo( + 'mean', + nan_policy='propagate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # FIXME: mean needs 'dim' parameter when using the 'out' overload. + # Adding it with 'generate_args_kwargs' does not work, since these also get passed + # onto the reference implementations. + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.mean), + error_inputs_func=error_inputs_mean, + skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), + # FIXME: mean does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + ), + ), + ReductionOpInfo( + 'nanmean', + nan_policy='omit', + assert_autodiffed=True, + promotes_int_to_float=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nanmean), + skips=( + # AssertionError: False is not true : + # Failure in testing nodes' autodifferentiation. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + device_type='cuda', dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + ReductionOpInfo( + 'std', + nan_policy='propagate', + supports_out=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.std), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + ReductionOpInfo( + 'std', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'var', + nan_policy='propagate', + supports_out=True, + assert_autodiffed=True, + promotes_int_to_float=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.var), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), + # NumPy is giving NaN for this + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), + ), + ), + ReductionOpInfo( + 'var', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'prod', + identity=1, + nan_policy='propagate', + supports_multiple_dims=False, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_prod, + ref=prod_numpy, + skips=( + # FIXME: prod does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: prod does not support passing None to dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.uint8, torch.float16, torch.complex64]), + # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float16]), + ), + ), + ReductionOpInfo( + 'sum', + identity=0, + nan_policy='propagate', + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_reduction_numpy(np.sum), + error_inputs_sparse_func=error_inputs_sparse_reduction_sum, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc), + skips=( + # FIXME: sum does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + ReductionOpInfo( + 'nansum', + identity=0, + nan_policy='omit', + supports_out=True, + promotes_int_to_int64=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nansum), + skips=( + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: nansum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: flaky test so skipped instead of xfailed + # possibly bad low precision reference in numpy + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + ReductionOpInfo( + 'hash_tensor', + result_dtype=torch.uint64, + supports_autograd=False, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_hash_tensor, + skips=( + # hash_tensor reduces all dimensions when dim=[] (as do sum, prod etc.) + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # aten::hash_tensor hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + # NYI + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + # Sharding strategy NYI + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + ) + ), + OpInfo( + "nn.functional.ctc_loss", + dtypes=floating_types(), + supports_out=False, + sample_inputs_func=sample_inputs_ctc_loss, + # gradcheck_wrapper, see https://github.com/pytorch/pytorch/issues/52241 + gradcheck_wrapper=gradcheck_wrapper_ctc_loss, + skips=( + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.expectedFailure, + "TestBwdGradients", + "test_fn_gradgrad", + dtypes=(torch.float64,), + ), + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + # Ref: https://github.com/pytorch/pytorch/issues/85231 + DecorateInfo(unittest.skip("Fails with ASAN"), + 'TestProxyTensorOpInfo', + 'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN), + ), + ), + OpInfo( + "nn.functional.cosine_embedding_loss", + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_cosine_embedding_loss, + ), + OpInfo( + "nn.functional.nll_loss", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_nll_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + skips=( + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0, i1): + # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32)) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ), + ), + OpInfo( + "nn.functional.gaussian_nll_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_gaussian_nll_loss, + error_inputs_func=error_inputs_gaussian_nll_loss, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + ), + OpInfo( + "nn.functional.hinge_embedding_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_hinge_embedding_loss, + error_inputs_func=error_inputs_hinge_embedding_loss, + reference_inputs_func=reference_inputs_hinge_embedding_loss, + ), + OpInfo( + "nn.functional.huber_loss", + aten_backward_name='huber_loss_backward', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_huber_loss, + error_inputs_func=error_inputs_huber_loss, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ) + ), + OpInfo( + "nn.functional.pdist", + ref=reference_pdist, + sample_inputs_func=sample_inputs_pdist, + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + skips=( + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + ) + ), + OpInfo( + "nn.functional.poisson_nll_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_poisson_nll_loss, + error_inputs_func=error_inputs_poisson_nll_loss, + ), + OpInfo( + "argsort", + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_non_standard_bool_values", + dtypes=[torch.bool], + device_type='cuda', + active_if=not TEST_WITH_ROCM + ), + ), + ), + OpInfo( + "repeat_interleave", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_repeat_interleave, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pairwise_distance", + ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: ( + np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p) + ), + sample_inputs_func=sample_inputs_pairwise_distance, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_shuffle", + sample_inputs_func=sample_inputs_pixel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_unshuffle", + sample_inputs_func=sample_inputs_pixel_unshuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.channel_shuffle", + sample_inputs_func=sample_inputs_channel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[0], + allow_cow_input_materialize_backward=[0, 'output grad 0'], + skips=( + # Skip due to NotImplementedError for MPS device. + DecorateInfo(unittest.expectedFailure, 'TestConsistency'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + ), + ), + OpInfo( + "nn.functional.kl_div", + sample_inputs_func=sample_inputs_kl_div, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "diagflat", + ref=lambda input, offset=0: np.diagflat(input, k=offset), + sample_inputs_func=sample_inputs_diagflat, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='sum', + inplace_variant=torch.Tensor.scatter_reduce_, + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='prod', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Not implemented + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='mean', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amin', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amax', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='lengths', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=sample_inputs_segment_reduce, + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='offsets', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'), + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), +] +op_db += opinfo.definitions.op_db + + +# Separate registry for experimental Python Reference OpInfos. +python_ref_db = [ + # + # Elementwise Unary OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.abs", + torch_opinfo_name="abs", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acos", + torch_opinfo_name="acos", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acosh", + torch_opinfo_name="acosh", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asin", + torch_opinfo_name="asin", + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asinh", + torch_opinfo_name="asinh", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + PythonRefInfo( + "_refs.lerp", + torch_opinfo_name="lerp", + ), + PythonRefInfo( + "_refs.ones", + torch_opinfo_name="ones", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.zeros", + torch_opinfo_name="zeros", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.cauchy", + torch_opinfo_name="cauchy", + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.exponential", + torch_opinfo_name="exponential", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.geometric", + torch_opinfo_name="geometric", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.log_normal", + torch_opinfo_name="log_normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + torch_opinfo_variant_name="number_mean", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal_", + op=torch.Tensor.normal_, + torch_opinfo_name="normal", + torch_opinfo_variant_name="in_place", + supports_out=False, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.arange", + torch_opinfo_name="arange", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO torch.ops.aten.copy is not in _refs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO copy doesn't have prim refs + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, torch.complex64, + torch.complex128, torch.bfloat16, torch.int8, torch.uint8 + ), + device_type="cuda" + ), + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, + torch.complex64, torch.complex128, torch.bfloat16, + torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8 + ), + device_type="cpu"), + ), + ), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="variadic_tensors", + ), + PythonRefInfo( + "_refs.take_along_dim", + torch_opinfo_name="take_along_dim", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.to", + torch_opinfo_name="to", + ), + PythonRefInfo( + "_refs.triu", + torch_opinfo_name="triu", + ), + PythonRefInfo( + "_refs.tril", + torch_opinfo_name="tril", + ), + PythonRefInfo( + "_refs.triu_indices", + torch_opinfo_name="triu_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.tril_indices", + torch_opinfo_name="tril_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="list_of_tensors", + ), + PythonRefInfo( + "_refs.movedim", + aliases=('moveaxis',), + torch_opinfo_name="movedim", + ), + PythonRefInfo( + "_refs.bucketize", + torch_opinfo_name="bucketize", + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor with + # aten._local_scalar_dense.default - erroring out! [...] + # triggered by mid_val = boundaries[mid] + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"), + ) + ), + PythonRefInfo( + "_refs.equal", + torch_opinfo_name="equal", + skips=( + # RuntimeError: Cannot cast FakeTensor to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atan", + torch_opinfo_name="atan", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atanh", + torch_opinfo_name="atanh", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.bitwise_not", + torch_opinfo_name="bitwise_not", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.ceil", + torch_opinfo_name="ceil", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.item", + torch_opinfo_name="item", + skips=( + # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), + # ValueError: Can't convert a tensor with 10 elements to a number! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj_physical", + torch_opinfo_name="conj_physical", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cos", + torch_opinfo_name="cos", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cosh", + torch_opinfo_name="cosh", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.digamma", + torch_opinfo_name="digamma", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erf", + torch_opinfo_name="erf", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfinv", + torch_opinfo_name="erfinv", + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfc", + torch_opinfo_name="erfc", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp", + torch_opinfo_name="exp", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.expm1", + torch_opinfo_name="expm1", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp2", + torch_opinfo_name="exp2", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.fill", + torch_opinfo_name="fill", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.floor", + torch_opinfo_name="floor", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frexp", + torch_opinfo_name="frexp", + # Skipped due to numerical failures on Windows CI. + # This is also skipped in frexp earlier in the file. + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frac", + torch_opinfo_name="frac", + skips=( + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.imag", + torch_opinfo_name="imag", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isfinite", + torch_opinfo_name="isfinite", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isinf", + torch_opinfo_name="isinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isposinf", + torch_opinfo_name="isposinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isneginf", + torch_opinfo_name="isneginf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isnan", + torch_opinfo_name="isnan", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isreal", + torch_opinfo_name="isreal", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.i0", + torch_opinfo_name="i0", + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.int8,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.lgamma", + torch_opinfo_name="lgamma", + decorators=(precisionOverride({torch.float16: 7e-1}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_1", + skips=skips_mvlgamma(), + decorators=( + DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_3", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_5", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log", + torch_opinfo_name="log", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log1p", + torch_opinfo_name="log1p", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log10", + torch_opinfo_name="log10", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log2", + torch_opinfo_name="log2", + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + PythonRefInfo( + "_refs.logsumexp", + torch_opinfo_name="logsumexp", + # When keepdim=False logsumexp function uses squeeze operation + # that is not yet exposed in nvFuser's Python API. + ), + PythonRefInfo( + "_refs.log_softmax", + torch_opinfo_name="log_softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nan_to_num", + torch_opinfo_name="nan_to_num", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.neg", + torch_opinfo_name="neg", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.positive", + torch_opinfo_name="positive", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.real", + torch_opinfo_name="real", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.reciprocal", + torch_opinfo_name="reciprocal", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.round", + torch_opinfo_name="round", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + skips=( + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rsqrt", + torch_opinfo_name="rsqrt", + decorators=(precisionOverride({torch.half: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sigmoid", + torch_opinfo_name="sigmoid", + aliases=('_refs.special.expit',), + # Reference: https://github.com/pytorch/pytorch/issues/56012 + handles_complex_extremal_values=False, + handles_large_floats=False, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda') + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sign", + torch_opinfo_name="sign", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sgn", + torch_opinfo_name="sgn", + # This is an issue with the vectorised abs on CPU + handles_complex_extremal_values=False, + handles_large_floats=False, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.signbit", + torch_opinfo_name="signbit", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sin", + torch_opinfo_name="sin", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinc", + torch_opinfo_name="sinc", + decorators=(precisionOverride({torch.bfloat16: 1e-2, + torch.float16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49133 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.cfloat]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinh", + torch_opinfo_name="sinh", + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + ), + ), + PythonRefInfo( + "_refs.softmax", + torch_opinfo_name="softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sqrt", + torch_opinfo_name="sqrt", + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.bfloat16,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.square", + torch_opinfo_name="square", + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + skips=( + # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)), + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tan", + torch_opinfo_name="tan", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tanh", + torch_opinfo_name="tanh", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.trunc", + torch_opinfo_name="trunc", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.special.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.special.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + # + # Elementwise Unary Special OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.special.logit", + torch_opinfo_name="logit", + ), + # + # Elementwise Unary nn.functional OpInfos + # + PythonRefInfo( + "_refs.nn.functional.alpha_dropout", + torch_opinfo_name="nn.functional.alpha_dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.celu", + torch_opinfo_name="nn.functional.celu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.channel_shuffle", + torch_opinfo_name="nn.functional.channel_shuffle", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.threshold", + torch_opinfo_name="nn.functional.threshold", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.dropout", + torch_opinfo_name="nn.functional.dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # dropout is not comparable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.elu", + torch_opinfo_name="nn.functional.elu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardtanh", + torch_opinfo_name="nn.functional.hardtanh", + supports_out=True, + ), + PythonRefInfo( # TODO: Port this to an UnaryOpInfo + "_refs.nn.functional.gelu", + torch_opinfo_name="nn.functional.gelu", + ), + PythonRefInfo( + "_refs.nn.functional.layer_norm", + torch_opinfo_name="nn.functional.layer_norm", + skips=( + # Reference result was farther (3.5762786809723224e-07) from the precise computation + # than the torch result was (2.5068410824946596e-07)! + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float32,), device_type='cpu'), + ), + ), + PythonRefInfo( + "_refs.nn.functional.glu", + torch_opinfo_name="nn.functional.glu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pairwise_distance", + torch_opinfo_name="nn.functional.pairwise_distance", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pdist", + torch_opinfo_name="nn.functional.pdist", + supports_out=True, + skips=( + # RunTimeError: no _refs support for torch.Tensor.index_select + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Reference result was farther (1.946091651916504e-05) from the precise + # computation than the torch result was (1.1920928955078125e-06)! + DecorateInfo( + unittest.expectedFailure, + 'TestCommon', + 'test_python_ref_torch_fallback', + dtypes=(torch.float32,), + device_type='cpu', + ), + )), + PythonRefInfo( + "_refs.nn.functional.leaky_relu", + torch_opinfo_name="nn.functional.leaky_relu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.pixel_shuffle", + torch_opinfo_name="nn.functional.pixel_shuffle", + ), + PythonRefInfo( + "_refs.nn.functional.pixel_unshuffle", + torch_opinfo_name="nn.functional.pixel_unshuffle", + ), + PythonRefInfo( + "_refs.nn.functional.poisson_nll_loss", + torch_opinfo_name="nn.functional.poisson_nll_loss", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.prelu", + torch_opinfo_name="nn.functional.prelu", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu", + torch_opinfo_name="nn.functional.relu", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu6", + torch_opinfo_name="nn.functional.relu6", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.mish", + torch_opinfo_name="nn.functional.mish", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), + 'TestUnaryUfuncs',), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.selu", + torch_opinfo_name="nn.functional.selu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + PythonRefInfo( + "_refs.nn.functional.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.softmin", + torch_opinfo_name="nn.functional.softmin", + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softplus", + torch_opinfo_name="nn.functional.softplus", + ), + PythonRefInfo( + "_refs.nn.functional.l1_loss", + torch_opinfo_name="nn.functional.l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.margin_ranking_loss", + torch_opinfo_name="nn.functional.margin_ranking_loss", + ), + PythonRefInfo( + "_refs.nn.functional.mse_loss", + torch_opinfo_name="nn.functional.mse_loss", + ), + PythonRefInfo( + "_refs.nn.functional.smooth_l1_loss", + torch_opinfo_name="nn.functional.smooth_l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.hinge_embedding_loss", + torch_opinfo_name="nn.functional.hinge_embedding_loss" + ), + PythonRefInfo( + "_refs.nn.functional.nll_loss", + torch_opinfo_name="nn.functional.nll_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + # For simpler indexing, we flatten target indices, then reshape the result tensor. + # This creates inconsistent view state with reference impl. + validate_view_consistency=False, + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda" + ), + ), + ), + PythonRefInfo( + "_refs.nn.functional.huber_loss", + torch_opinfo_name="nn.functional.huber_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.tanhshrink", + torch_opinfo_name="nn.functional.tanhshrink", + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02), + torch.complex64: tol(atol=6e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), + active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), + device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardshrink", + torch_opinfo_name="nn.functional.hardshrink", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softshrink", + torch_opinfo_name="nn.functional.softshrink", + ), + # + # Elementwise Binary Reference OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.add", + torch_opinfo_name="add", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.atan2", + torch_opinfo_name="atan2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_and", + torch_opinfo_name="bitwise_and", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_left_shift", + torch_opinfo_name="bitwise_left_shift", + skips=( + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_right_shift", + torch_opinfo_name="bitwise_right_shift", + skips=( + # # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Skipped some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_or", + torch_opinfo_name="bitwise_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_xor", + torch_opinfo_name="bitwise_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.copysign", + torch_opinfo_name="copysign", + skips=( + # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu! + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="no_rounding_mode", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # NotImplementedError: argument of type: + DecorateInfo( + unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128,) + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="trunc_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="floor_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Reference result was farther (nan) from the precise computation than the + # torch result was (inf)! + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_python_ref", + dtypes=(torch.bfloat16,), + device_type="cpu", + active_if=not IS_S390X, + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.eq", + torch_opinfo_name="eq", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.float_power", + torch_opinfo_name="float_power", + skips=( + # Test doesn't account for float -> double type promotion + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logaddexp", + torch_opinfo_name="logaddexp", + skips=( + # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + ), + ), + PythonRefInfo( + "_refs.logaddexp2", + torch_opinfo_name="logaddexp2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.floor_divide", + torch_opinfo_name="floor_divide", + rhs_make_tensor_kwargs=dict(exclude_zero=True), + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + # bfloat16 floor_divide compared with a float32 reference works inconsistently + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,)), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmax", + torch_opinfo_name="fmax", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmin", + torch_opinfo_name="fmin", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmod", + torch_opinfo_name="fmod", + rhs_make_tensor_kwargs={'exclude_zero': True}, + supports_rhs_python_scalar=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gcd", + torch_opinfo_name="gcd", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ge", + torch_opinfo_name="ge", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gt", + torch_opinfo_name="gt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.heaviside", + torch_opinfo_name="heaviside", + supports_rhs_python_scalar=False, + skips=( + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.hypot", + torch_opinfo_name="hypot", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igamma", + torch_opinfo_name="igamma", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igammac", + torch_opinfo_name="igammac", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.isclose", + torch_opinfo_name="isclose", + skips=( + # Intentional xfail -- isclose does not type promote + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lcm", + torch_opinfo_name="lcm", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.le", + torch_opinfo_name="le", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_and", + torch_opinfo_name="logical_and", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.logical_not", + torch_opinfo_name="logical_not", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_or", + torch_opinfo_name="logical_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_xor", + torch_opinfo_name="logical_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lt", + torch_opinfo_name="lt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.maximum", + torch_opinfo_name="maximum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.minimum", + torch_opinfo_name="minimum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.mul", + torch_opinfo_name="mul", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type='cuda' + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type='cuda' + ), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ne", + torch_opinfo_name="ne", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.nextafter", + torch_opinfo_name="nextafter", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.pow", + torch_opinfo_name="pow", + decorators=( + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.remainder", + torch_opinfo_name="remainder", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.rsub", + torch_opinfo_name="rsub", + # https://github.com/pytorch/pytorch/issues/76944 + skips=( + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.chalf,), device_type='cpu'), + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.chalf,), device_type='cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.sub", + torch_opinfo_name="sub", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.true_divide", + torch_opinfo_name="true_divide", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + # + # Elementwise Ternary Reference OpInfos + # + PythonRefInfo( + "_refs.addcdiv", + torch_opinfo_name="addcdiv", + ), + PythonRefInfo( + "_refs.addcmul", + torch_opinfo_name="addcmul", + skips=( + # Reference result was farther (1.3343989849090576e-05) + # from the precise computation than the torch result + # was (9.592622518539429e-06)! + # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float16,), device_type="cpu"), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.float16,), device_type="cpu"), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_min", + torch_opinfo_name="clamp_min", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_max", + torch_opinfo_name="clamp_max", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.clamp", + torch_opinfo_name="clamp", + ), + PythonRefInfo( + "_refs.nn.functional.triplet_margin_loss", + torch_opinfo_name="nn.functional.triplet_margin_loss", + supports_out=False, + # TODO: Uses minimum and clamp + skips=( + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed) + # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8,), device_type="cpu"), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.xlogy", + torch_opinfo_name="xlogy", + supports_one_python_scalar=True, + ), + # + # Elementwise Binary Special OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.special.xlog1py", + torch_opinfo_name="special.xlog1py", + supports_one_python_scalar=True, + ), + # + # Data Conversion & Data Movement Opinfos + # + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bfloat16", + torch_opinfo_name="bfloat16", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bool", + torch_opinfo_name="bool", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.byte", + torch_opinfo_name="byte", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.char", + torch_opinfo_name="char", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.complex", + torch_opinfo_name="complex", + error_inputs_func=partial(error_inputs_complex, is_ref=True), + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.polar", + torch_opinfo_name="polar", + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.double", + torch_opinfo_name="double", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.float", + torch_opinfo_name="float", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.half", + torch_opinfo_name="half", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.int", + torch_opinfo_name="int", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.long", + torch_opinfo_name="long", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.short", + torch_opinfo_name="short", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.chalf", + torch_opinfo_name="chalf", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cfloat", + torch_opinfo_name="cfloat", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cdouble", + torch_opinfo_name="cdouble", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.clone", + torch_opinfo_name="clone", + ), + # + # View & Shape OpInfos + # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.atleast_1d", + torch_opinfo_name="atleast_1d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_2d", + torch_opinfo_name="atleast_2d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_3d", + torch_opinfo_name="atleast_3d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.as_strided_copy", + torch_opinfo_name="as_strided_copy", + supports_out=True, + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + torch_opinfo_variant_name="partial_views", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.as_strided_scatter", + torch_opinfo_name="as_strided_scatter", + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.block_diag", + torch_opinfo_name="block_diag", + ), + PythonRefInfo( + "_refs.broadcast_shapes", + torch_opinfo_name="broadcast_shapes", + ), + PythonRefInfo( + "_refs.broadcast_tensors", + torch_opinfo_name="broadcast_tensors", + ), + PythonRefInfo( + "_refs.broadcast_to", + torch_opinfo_name="broadcast_to", + ), + PythonRefInfo( + "_refs.cat", + torch_opinfo_name="cat", + skips=( + # FIXME: AssertionError: RuntimeError not raised + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.chunk", + torch_opinfo_name="chunk", + ), + PythonRefInfo( + "_refs.column_stack", + torch_opinfo_name="column_stack", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj", + torch_opinfo_name="conj", + ), + PythonRefInfo( + "_refs.constant_pad_nd", + torch_opinfo_name="constant_pad_nd", + ), + PythonRefInfo( + "_refs.contiguous", + torch_opinfo_name="contiguous", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.deg2rad", + torch_opinfo_name="deg2rad", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.dsplit", + torch_opinfo_name="dsplit", + ), + PythonRefInfo( + "_refs.diag", + torch_opinfo_name="diag", + ), + PythonRefInfo( + "_refs.diagonal", + torch_opinfo_name="diagonal", + ), + PythonRefInfo( + "_refs.diagonal_copy", + torch_opinfo_name="diagonal_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.diagonal_scatter", + torch_opinfo_name="diagonal_scatter", + supports_out=True, + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.diag_embed", + torch_opinfo_name="diag_embed", + supports_out=True, + ), + PythonRefInfo( + "_refs.dstack", + torch_opinfo_name="dstack", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.expand", + torch_opinfo_name="expand", + ), + PythonRefInfo( + "_refs.expand_as", + torch_opinfo_name="expand_as", + ), + PythonRefInfo( + "_refs.expand_copy", + torch_opinfo_name="expand_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.flatten", + torch_opinfo_name="flatten", + ), + PythonRefInfo( + "_refs.flip", + torch_opinfo_name="flip", + ), + PythonRefInfo( + "_refs.fliplr", + torch_opinfo_name="fliplr", + ), + PythonRefInfo( + "_refs.flipud", + torch_opinfo_name="flipud", + ), + PythonRefInfo( + "_refs.hstack", + torch_opinfo_name="hstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.narrow", + torch_opinfo_name="narrow", + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), + ), + PythonRefInfo( + "_refs.narrow_copy", + torch_opinfo_name="narrow_copy", + supports_out=True, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + skips=( + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.native_layer_norm", + torch_opinfo_name="native_layer_norm", + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref", + device_type="cpu", dtypes=(torch.float32,)), + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback", + device_type="cpu", dtypes=(torch.float32,)), + ), + ), + PythonRefInfo( + "_refs.permute", + torch_opinfo_name="permute", + ), + PythonRefInfo( + "_refs.permute_copy", + torch_opinfo_name="permute_copy", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rad2deg", + torch_opinfo_name="rad2deg", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.ravel", + torch_opinfo_name="ravel", + ), + PythonRefInfo( + "_refs.renorm", + torch_opinfo_name="renorm", + ), + PythonRefInfo( + "_refs.repeat", + torch_opinfo_name="repeat", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.reshape", + torch_opinfo_name="reshape", + ), + PythonRefInfo( + "_refs.reshape_as", + torch_opinfo_name="reshape_as", + ), + PythonRefInfo( + "_refs.roll", + torch_opinfo_name="roll", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.rot90", + torch_opinfo_name="rot90", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.select_scatter", + torch_opinfo_name="select_scatter", + ), + PythonRefInfo( + "_refs.stack", + torch_opinfo_name="stack", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + ), + PythonRefInfo( + "_refs.squeeze_copy", + torch_opinfo_name="squeeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + torch_opinfo_variant_name="multiple", + ), + PythonRefInfo( + "_refs.tensor_split", + torch_opinfo_name="tensor_split", + skips=( + # RuntimeError: no _refs support for torch.Tensor.tolist + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.hsplit", + torch_opinfo_name="hsplit", + ), + PythonRefInfo( + "_refs.vsplit", + torch_opinfo_name="vsplit", + ), + PythonRefInfo( + "_refs.dot", + torch_opinfo_name="dot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.vdot", + torch_opinfo_name="vdot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.transpose", + torch_opinfo_name="transpose", + ), + PythonRefInfo( + "_refs.transpose_copy", + torch_opinfo_name="transpose_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.t", + torch_opinfo_name="t", + ), + PythonRefInfo( + "_refs.t_copy", + torch_opinfo_name="t_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.T", + torch_opinfo_name="T", + error_inputs_func=partial(error_inputs_T, has_ndims_error=True), + ), + PythonRefInfo( + "_refs.unbind_copy", + torch_opinfo_name="unbind_copy", + ), + PythonRefInfo( + "_refs.unfold", + torch_opinfo_name="unfold", + ), + PythonRefInfo( + "_refs.unfold_copy", + torch_opinfo_name="unfold_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.unsqueeze", + torch_opinfo_name="unsqueeze", + ), + PythonRefInfo( + "_refs.unsqueeze_copy", + torch_opinfo_name="unsqueeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.view", + torch_opinfo_name="view", + ), + PythonRefInfo( + "_refs.view_as", + torch_opinfo_name="view_as", + ), + PythonRefInfo( + "_refs.view_copy", + torch_opinfo_name="view_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.vstack", + torch_opinfo_name="vstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.unflatten", + torch_opinfo_name="unflatten", + ), + PythonRefInfo( + "_refs.unbind", + torch_opinfo_name="unbind", + ), + # + # Reduction Reference OpInfos + # + ReductionPythonRefInfo( + "_refs.all", + torch_opinfo_name="all", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.amax", + torch_opinfo_name="amax", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.amin", + torch_opinfo_name="amin", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.any", + torch_opinfo_name="any", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.count_nonzero", + torch_opinfo_name="count_nonzero", + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_default_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_multi_unsorted_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionPythonRefInfo( + "_refs.mean", + torch_opinfo_name="mean", + supports_out=True, + error_inputs_func=partial(error_inputs_mean, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.std", + torch_opinfo_name="std", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + # std_mean and var_mean are not ReductionInfos + PythonRefInfo( + "_refs.std_mean", + torch_opinfo_name="std_mean", + ), + ReductionPythonRefInfo( + "_refs.sum", + torch_opinfo_name="sum", + supports_out=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + PythonRefInfo( + "_refs.cumsum", + torch_opinfo_name="cumsum", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.cumprod", + torch_opinfo_name="cumprod", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.sum_to_size", + torch_opinfo_name="sum_to_size", + validate_view_consistency=False, + ), + ReductionPythonRefInfo( + "_refs.prod", + torch_opinfo_name="prod", + supports_out=True, + supports_multiple_dims=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + ), + ), + ReductionPythonRefInfo( + "_refs.var", + torch_opinfo_name="var", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + ), + ), + PythonRefInfo( + "_refs.var_mean", + torch_opinfo_name="var_mean", + validate_view_consistency=False, + ), + # + # Linear Algebra Operators + # + PythonRefInfo( + "_refs.addr", + torch_opinfo_name="addr", + decorators=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), + ), + ), + PythonRefInfo( + "_refs.trace", + torch_opinfo_name="trace", + ), + PythonRefInfo( + "_refs.norm", + torch_opinfo_name="norm", + supports_out=True, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + # + # Tensor Creation Reference OpInfos + # + PythonRefInfo( + "_refs.empty", + torch_opinfo_name="empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: shouldn't check empty results + DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.empty_like", + torch_opinfo_name="empty_like", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.randn", + torch_opinfo_name="randn", + op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs), + skips=( + # see https://github.com/pytorch/pytorch/issues/85121 + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), + 'TestCommon', + 'test_python_ref_executor'), + # These tests expect the input to be a tensor or a sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.eye", + torch_opinfo_name="eye", + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + ), + ), + PythonRefInfo( + "_refs.new_empty", + torch_opinfo_name="new_empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_empty_strided", + torch_opinfo_name="new_empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + ), + ), + PythonRefInfo( + "_refs.empty_strided", + torch_opinfo_name="empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_full", + torch_opinfo_name="new_full", + ), + PythonRefInfo( + "_refs.new_ones", + torch_opinfo_name="new_ones", + ), + PythonRefInfo( + "_refs.new_zeros", + torch_opinfo_name="new_zeros", + ), + # + # Conditional Reference OpInfos + # + PythonRefInfo( + "_refs.masked_fill", + torch_opinfo_name="masked_fill", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.where", + torch_opinfo_name="where", + op=lambda self, condition, other: refs.where(condition, self, other), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'), + ), + ), + PythonRefInfo( + "_refs.index_select", + torch_opinfo_name="index_select", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Sample out= with a stride of zero. This _out operation checks that the input has no + # inner overlap + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),) + ), + PythonRefInfo( + "_refs.index_copy", + torch_opinfo_name="index_copy", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.index_add", + torch_opinfo_name="index_add", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.index_fill", + torch_opinfo_name="index_fill", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),) + ), + # + # Test-related functions + # + PythonRefInfo( + "_refs.allclose", + torch_opinfo_name="allclose", + ), + # + # Misc functions + # + PythonRefInfo( + "_refs.stft", + torch_opinfo_name="stft", + skips=[ + # RuntimeError: no _refs support for aten.pad + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + ], + ), + PythonRefInfo( + "_refs.istft", + torch_opinfo_name="istft", + skips=[ + # RuntimeError: no _refs support for aten.unfold_backward + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + DecorateInfo( + unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"), + 'TestCommon', + 'test_python_ref_executor', + dtypes=(torch.complex64, torch.complex128), + ), + ], + ), + PythonRefInfo( + "_refs.view_as_complex", + torch_opinfo_name="view_as_complex", + ), + PythonRefInfo( + "_refs.split_with_sizes", + torch_opinfo_name="split_with_sizes", + ), +] +python_ref_db += opinfo.definitions.python_ref_db + +# Common operator groupings +ops_and_refs = op_db + python_ref_db +unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)] +binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)] +binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)) +spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)] +sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse] +sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr] +sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse] +shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)] +reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)] +reference_filtered_ops = [op for op in reduction_ops if op.ref is not None] +reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')] +sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')] + +def index_variable(shape, max_indices, device=torch.device('cpu')): + if not isinstance(shape, tuple): + shape = (shape,) + return torch.testing.make_tensor(*shape, dtype=torch.long, device=device, low=0, high=max_indices) + +def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')): + assert len(shape) == 2 + assert index_dim < 2 + batch_dim = 1 - index_dim + index = torch.zeros(*shape, dtype=torch.long, device=device) + for i in range(shape[index_dim]): + index.select(index_dim, i).copy_( + torch.randperm(max_indices, device=device)[:shape[batch_dim]]) + if duplicate: + index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) + return index + +def bernoulli_scalar(): + return torch.tensor(0, dtype=torch.bool).bernoulli_() + +def mask_not_all_zeros(shape): + assert len(shape) > 0 + while True: + result = torch.randn(shape).gt(0) + if result.sum() > 0: + return result + +# Copied from functorch +def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + + +def skip(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [o for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for op in matching_opinfos: + decorators = list(op.decorators) + if expected_failure: + decorator = DecorateInfo(unittest.expectedFailure, + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + else: + decorator = DecorateInfo(unittest.skip("Skipped!"), + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + op.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + return wrapped diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..13cd86e05bd6f7b4e9515cf102cc1e6d3b49781d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py @@ -0,0 +1,385 @@ +# Owner(s): ["module: unknown"] + +from typing import Any +from torch.ao.pruning import BaseSparsifier +import torch +import torch.nn.functional as F +from torch import nn + +class ImplementedSparsifier(BaseSparsifier): + def __init__(self, **kwargs: dict[str, Any]) -> None: + super().__init__(defaults=kwargs) + + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None: + module.parametrizations.weight[0].mask[0] = 0 # type: ignore[index, union-attr] + linear_state = self.state['linear1.weight'] + linear_state['step_count'] = linear_state.get('step_count', 0) + 1 + + +class MockSparseLinear(nn.Linear): + """ + This class is a MockSparseLinear class to check convert functionality. + It is the same as a normal Linear layer, except with a different type, as + well as an additional from_dense method. + """ + @classmethod + def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear': + """ + """ + linear = cls(mod.in_features, + mod.out_features) + return linear + + +def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool: + """ + Checks to see if all rows in subset tensor are present in the superset tensor + """ + i = 0 + for row in subset_tensor: + while i < len(superset_tensor): + if not torch.equal(row, superset_tensor[i]): + i += 1 + else: + break + else: + return False + return True + + +class SimpleLinear(nn.Module): + r"""Model with only Linear layers without biases, some wrapped in a Sequential, + some following the Sequential. Used to test basic pruned Linear-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=False), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 4, bias=False), + ) + self.linear1 = nn.Linear(4, 4, bias=False) + self.linear2 = nn.Linear(4, 10, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.linear2(x) + return x + + +class LinearBias(nn.Module): + r"""Model with only Linear layers, alternating layers with biases, + wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 3, bias=True), + nn.Linear(3, 3, bias=True), + nn.Linear(3, 10, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + return x + + +class LinearActivation(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.Tanh(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.act1 = nn.ReLU() + self.linear2 = nn.Linear(3, 10, bias=False) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.act1(x) + x = self.linear2(x) + x = self.act2(x) + return x + + +class LinearActivationFunctional(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and functional + activationals are called in between each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.ReLU(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.linear2 = nn.Linear(3, 8, bias=False) + self.linear3 = nn.Linear(8, 10, bias=False) + self.act1 = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + x = F.relu(x) + x = self.linear3(x) + x = F.relu(x) + return x + + +class SimpleConv2d(nn.Module): + r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following. + Used to test pruned Conv2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=False), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dBias(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside. + Used to test pruned Conv2d-Bias-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.Conv2d(32, 32, 3, 1, bias=True), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dActivation(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following. + Activation function modules in between each Sequential layer, functional activations called + in-between each outside layer. + Used to test pruned Conv2d-Bias-Activation-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + nn.Conv2d(64, 64, 3, 1, bias=False), + nn.ReLU(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.relu(x) + x = self.conv2d2(x) + x = F.hardtanh(x) + return x + + +class Conv2dPadBias(nn.Module): + r"""Model with only Conv2d layers, all with bias and some with padding > 0, + some in a Sequential and some following. Activation function modules in between each layer. + Used to test that bias is propagated correctly in the special case of + pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, bias=False), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True) + self.act1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.act1(x) + x = self.conv2d2(x) + x = self.act2(x) + return x + + +class Conv2dPool(nn.Module): + r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following. + Activation function modules in between each layer, Pool2d modules in between each layer. + Used to test pruned Conv2d-Pool2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True) + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True) + self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.maxpool(x) + x = self.af1(x) + x = self.conv2d2(x) + x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1) + x = F.relu(x) + x = self.conv2d3(x) + return x + + +class Conv2dPoolFlattenFunctional(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a functional Flatten followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(11, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = torch.flatten(x, 1) # test functional flatten + x = self.fc(x) + return x + + +class Conv2dPoolFlatten(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a Flatten module followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((2, 2)) + self.flatten = nn.Flatten() + self.fc = nn.Linear(44, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class LSTMLinearModel(nn.Module): + """Container module with an encoder, a recurrent module, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + output, _hidden = self.lstm(input) + decoded = self.linear(output) + return decoded, output + + +class LSTMLayerNormLinearModel(nn.Module): + """Container module with an LSTM, a LayerNorm, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.norm = nn.LayerNorm(hidden_dim) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x, state = self.lstm(x) + x = self.norm(x) + x = self.linear(x) + return x, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4fab8c48bbd84c631838b76ff8d7535046a98b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py @@ -0,0 +1,3415 @@ +# mypy: ignore-errors + +r"""Importing this file includes common utility methods and base classes for +checking quantization api and properties of resulting modules. +""" + +import torch +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from functorch.experimental import control_flow +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization import ( + convert, + default_dynamic_qat_qconfig, + default_dynamic_qconfig, + default_dynamic_quant_observer, + default_embedding_qat_qconfig, + default_observer, + default_per_channel_qconfig, + default_qconfig, + default_symmetric_qnnpack_qat_qconfig, + default_weight_observer, + DeQuantStub, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qat_qconfig_mapping, + get_default_qconfig, + get_default_qconfig_mapping, + PerChannelMinMaxObserver, + propagate_qconfig_, + QConfig, + QConfigMapping, + quantize, + quantize_dynamic_jit, + quantize_jit, + QuantStub, + QuantType, + QuantWrapper, +) +from torch.ao.quantization.backend_config import get_executorch_backend_config +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, +) +from torch.ao.quantization.quantize_pt2e import ( + _convert_to_reference_decomposed_fx, + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) + +from torch.export import export +from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase + +try: + from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph + + # graph mode quantization based on fx + from torch.ao.quantization.quantize_fx import ( + convert_fx, + convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, + ) + from torch.fx import GraphModule + from torch.fx.graph import Node + + HAS_FX = True +except ImportError: + HAS_FX = False + +import contextlib +import copy +import functools +import io +import os + +import unittest +from typing import Any, Optional, Union +from collections.abc import Callable + +import numpy as np +import torch._dynamo as torchdynamo +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer +from torch.testing import FileCheck + + +class NodeSpec: + """Used for checking GraphModule Node""" + + def __init__(self, op, target): + """ + op: call_function | call_module + target: + for call_function, target would be a function + for call_module, target would be the type of PyTorch module + """ + self.op = op + self.target = target + + @classmethod + def call_function(cls, target): + return NodeSpec("call_function", target) + + @classmethod + def call_method(cls, target): + return NodeSpec("call_method", target) + + @classmethod + def call_module(cls, target): + return NodeSpec("call_module", target) + + def __hash__(self): + return hash((self.op, self.target)) + + def __eq__(self, other): + if not isinstance(other, NodeSpec): + return NotImplemented + + return self.op == other.op and self.target == other.target + + def __repr__(self): + return repr(self.op) + " " + repr(self.target) + + +def get_supported_device_types(): + return ( + ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"] + ) + + +def test_only_eval_fn(model, calib_data): + r""" + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for inp in calib_data: + model(*inp) + + +_default_loss_fn = torch.nn.CrossEntropyLoss() + + +def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): + r""" + Default train function takes a torch.utils.data.Dataset and train the model + on the dataset + """ + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + train_loss, correct, total = 0, 0, 0 + for _ in range(10): + model.train() + + for data, target in train_data: + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + train_loss += loss.item() + _, predicted = torch.max(output, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + return train_loss, correct, total + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): + model.train() + for cnt, (image, target) in enumerate(data_loader, start=1): + print(".", end="") + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + accuracy(output, target, topk=(1, 5)) + if cnt >= ntrain_batches: + return + return + + +def ddp_setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + +def ddp_cleanup(): + dist.destroy_process_group() + + +def run_ddp(rank, world_size, prepared): + ddp_setup(rank, world_size) + prepared.cuda() + prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank]) + prepared.to(rank) + model_with_ddp = prepared + optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001) + train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1) # noqa: F821 + ddp_cleanup() + + +def convert_dynamic(module): + convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) + + +def prepare_dynamic(model, qconfig_dict=None): + propagate_qconfig_(model, qconfig_dict) + + +def _make_conv_test_input( + batch_size, + in_channels_per_group, + input_feature_map_size, + out_channels_per_group, + groups, + kernel_size, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + use_bias, + use_channelwise, +): + in_channels = in_channels_per_group * groups + out_channels = out_channels_per_group * groups + + (X_value_min, X_value_max) = (0, 4) + X_init = torch.randint( + X_value_min, + X_value_max, + ( + batch_size, + in_channels, + ) + + input_feature_map_size, + ) + X = X_scale * (X_init - X_zero_point).float() + X_q = torch.quantize_per_tensor( + X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8 + ) + + W_scale = W_scale * out_channels + W_zero_point = W_zero_point * out_channels + # Resize W_scale and W_zero_points arrays equal to out_channels + W_scale = W_scale[:out_channels] + W_zero_point = W_zero_point[:out_channels] + # For testing, we use small values for weights and for activations so that + # no overflow occurs in vpmaddubsw instruction. If the overflow occurs in + # qconv implementation and if there is no overflow. + # In reference we can't exactly match the results with reference. + # Please see the comment in qconv implementation file + # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details. + (W_value_min, W_value_max) = (-5, 5) + # The operator expects them in the format + # (out_channels, in_channels/groups,) + kernel_size + W_init = torch.randint( + W_value_min, + W_value_max, + ( + out_channels, + in_channels_per_group, + ) + + kernel_size, + ) + b_init = torch.randint(0, 10, (out_channels,)) + + if use_channelwise: + W_shape = (-1, 1) + (1,) * len(kernel_size) + W_scales_tensor = torch.tensor(W_scale, dtype=torch.float) + W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float) + W = ( + W_scales_tensor.reshape(*W_shape) + * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + ) + b = X_scale * W_scales_tensor * b_init.float() + W_q = torch.quantize_per_channel( + W, + W_scales_tensor.double(), + W_zero_points_tensor.long(), + 0, + dtype=torch.qint8, + ) + else: + W = W_scale[0] * (W_init - W_zero_point[0]).float() + b = X_scale * W_scale[0] * b_init.float() + W_q = torch.quantize_per_tensor( + W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8 + ) + + return (X, X_q, W, W_q, b if use_bias else None) + + +def _make_conv_add_extra_input_tensor(scale, zero_point, sizes): + (X_value_min, X_value_max) = (0, 4) + X_init = torch.randint( + X_value_min, + X_value_max, + sizes, # Infer the size of tensor to do the add + ) + X = scale * (X_init - zero_point).float() + X_q = torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=torch.quint8 + ) + return X, X_q + + +def skipIfNoFBGEMM(fn): + reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer." + if isinstance(fn, type): + if "fbgemm" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "fbgemm" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoQNNPACK(fn): + reason = "Quantized operations require QNNPACK." + if isinstance(fn, type): + if "qnnpack" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "qnnpack" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def withQNNPACKBackend(fn): + # TODO(future PR): consider combining with skipIfNoQNNPACK, + # will require testing of existing callsites + reason = "Quantized operations require QNNPACK." + if isinstance(fn, type): + if "qnnpack" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "qnnpack" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + with override_quantized_engine("qnnpack"): + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoONEDNN(fn): + reason = "Quantized operations require ONEDNN." + if isinstance(fn, type): + if "onednn" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "onednn" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoONEDNNBF16(fn): + reason = "Quantized operations require BF16 support." + if isinstance(fn, type): + if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoX86(fn): + reason = "Quantized operations require X86." + if isinstance(fn, type): + if "x86" not in torch.backends.quantized.supported_engines: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if "x86" not in torch.backends.quantized.supported_engines: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoDynamoSupport(fn): + reason = "dynamo doesn't support." + if isinstance(fn, type): + if not torchdynamo.is_dynamo_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torchdynamo.is_dynamo_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +def skipIfNoInductorSupport(fn): + reason = "inductor doesn't support." + if isinstance(fn, type): + if not torchdynamo.is_inductor_supported(): + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = reason + return fn + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torchdynamo.is_inductor_supported(): + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + + return wrapper + + +try: + import torchvision # noqa: F401 + + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + + +def get_script_module(model, tracing, data): + return torch.jit.trace(model, data) if tracing else torch.jit.script(model) + + +def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): + """ + Convert lengths to offsets for embedding_bag + """ + tt = np.zeros((t.shape[0] + 1,), dtype=offset_type) + tt[1:] = t + tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type)) + if use_begin_offset: + return tt[:-1] + return tt[1:] + + +def _group_quantize_tensor(w, n_bit=4, q_group_size=16): + assert w.dim() == 2 + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_val + scales * (2 ** (n_bit - 1)) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device != torch.device("cpu"): + out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1) + zeros = zeros.view(w.shape[0], -1) + scales_and_zeros = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + return out, scales_and_zeros + + +def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): + # W is of shape [K x N] + # We transpose W as Quantization is applied on [N x K] + w = w.transpose(0, 1).contiguous() + assert w.dim() == 2 + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + # Calculate scale and zeros + to_quant = w.reshape(-1, groupsize) + max_val = to_quant.abs().amax(dim=1, keepdim=True) + eps = torch.finfo(max_val.dtype).eps + max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7 + scales = max_val.clamp(min=eps) / max_int + zeros = torch.zeros_like(scales) + + # Quantize the weight + scales = scales.to(torch.float32).reshape(w.shape[0], -1) + zeros = zeros.to(torch.float32).reshape(w.shape[0], -1) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + max_int = 2**n_bit - 1 + w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int) + # We pack 2 signed int4 values in unsigned uint8 container. + # This reduces the weight size by half and improves load perf + out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) + + scales_and_zeros = scales.squeeze().contiguous() + + return out_uint8, scales_and_zeros + + +def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py + # default setup for affine quantization of activations + x_dtype = x.dtype + x = x.float() + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales.to(x_dtype), zero_points + + +# QuantizationTestCase used as a base class for testing quantization on modules +class QuantizationTestCase(TestCase): + def setUp(self): + super().setUp() + self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] + self.train_data = [ + [ + torch.rand(2, 5, dtype=torch.float), + torch.randint(0, 1, (2,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)] + self.img_data_2d = [ + [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2) + ] + self.img_data_3d = [ + [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2) + ] + self.img_data_1d_train = [ + [ + torch.rand(2, 3, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_2d_train = [ + [ + torch.rand(1, 3, 10, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_3d_train = [ + [ + torch.rand(1, 3, 5, 5, 5, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + + self.img_data_dict = { + 1: self.img_data_1d, + 2: self.img_data_2d, + 3: self.img_data_3d, + } + + # Quant types that produce statically quantized ops + self.static_quant_types = [QuantType.STATIC, QuantType.QAT] + # All quant types for (fx based) graph mode quantization + self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT] + + def checkNoPrepModules(self, module): + r"""Checks the module does not contain child + modules for quantization preparation, e.g. + quant, dequant and observer + """ + self.assertFalse(hasattr(module, "quant")) + self.assertFalse(hasattr(module, "dequant")) + + def checkNoQconfig(self, module): + r"""Checks the module does not contain qconfig""" + self.assertFalse(hasattr(module, "qconfig")) + + for child in module.children(): + self.checkNoQconfig(child) + + def checkHasPrepModules(self, module): + r"""Checks the module contains child + modules for quantization preparation, e.g. + quant, dequant and observer + """ + self.assertTrue(hasattr(module, "module")) + self.assertTrue(hasattr(module, "quant")) + self.assertTrue(hasattr(module, "dequant")) + + def checkObservers( + self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None + ): + r"""Checks the module or module's leaf descendants + have observers in preparation for quantization + """ + if propagate_qconfig_list is None: + propagate_qconfig_list = get_default_qconfig_propagation_list() + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + float_to_observed_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) + + # check if a module is a leaf module, ignoring activation_post_process attribute + def is_leaf_module(module): + submodule_name_count = 0 + for name, _ in module.named_children(): + if name != "activation_post_process": + submodule_name_count += 1 + return submodule_name_count == 0 + + if ( + hasattr(module, "qconfig") + and module.qconfig is not None + and ( + ( + is_leaf_module(module) + and not isinstance(module, torch.nn.Sequential) + and type(module) in propagate_qconfig_list + ) + or type(module) in float_to_observed_module_class_mapping + ) + and not isinstance(module, torch.ao.quantization.DeQuantStub) + ): + self.assertTrue( + hasattr(module, "activation_post_process"), + "module: " + str(type(module)) + " do not have observer", + ) + # we don't need to check observers for child modules of the + # qat modules + if ( + type(module) not in get_default_qat_module_mappings().values() + and type(module) not in float_to_observed_module_class_mapping.values() + and not isinstance(module, _FusedModule) + ): + for child in module.children(): + if type(child) is nn.Dropout: + continue + self.checkObservers( + child, propagate_qconfig_list, prepare_custom_config_dict + ) + + def checkQuantDequant(self, mod): + r"""Checks that mod has nn.Quantize and + nn.DeQuantize submodules inserted + """ + self.assertEqual(type(mod.quant), nnq.Quantize) + self.assertEqual(type(mod.dequant), nnq.DeQuantize) + + def checkWrappedQuantizedLinear(self, mod): + r"""Checks that mod has been swapped for an nnq.Linear + module, the bias is qint32, and that the module + has Quantize and DeQuantize submodules + """ + self.assertEqual(type(mod.module), nnq.Linear) + self.checkQuantDequant(mod) + + def checkQuantizedLinear(self, mod): + self.assertEqual(type(mod), nnq.Linear) + + def checkDynamicQuantizedLinear(self, mod, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + self.assertEqual(type(mod), nnqd.Linear) + self.assertEqual(mod._packed_params.dtype, dtype) + + def checkDynamicQuantizedLinearRelu(self, mod, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + self.assertEqual(type(mod), nniqd.LinearReLU) + self.assertEqual(mod._packed_params.dtype, dtype) + + def check_eager_serialization(self, ref_model, loaded_model, x): + # Check state dict serialization and torch.save APIs + model_dict = ref_model.state_dict() + b = io.BytesIO() + torch.save(model_dict, b) + b.seek(0) + # weights_only=False as we sometimes get a ScriptObject here (weird) + loaded_dict = torch.load(b, weights_only=False) + loaded_model.load_state_dict(loaded_dict) + ref_out = ref_model(*x) + load_out = loaded_model(*x) + + def check_outputs(ref_out, load_out): + self.assertEqual(ref_out[0], load_out[0]) + if isinstance(ref_out[1], tuple): + self.assertEqual(ref_out[1][0], load_out[1][0]) + self.assertEqual(ref_out[1][1], load_out[1][1]) + else: + self.assertEqual(ref_out[1], load_out[1]) + + check_outputs(ref_out, load_out) + b = io.BytesIO() + torch.save(ref_model, b) + b.seek(0) + # weights_only=False as this is legacy code that saves the model + loaded = torch.load(b, weights_only=False) + load_out = loaded(*x) + check_outputs(ref_out, load_out) + + def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): + weight = ref_model.get_weight() + bias = ref_model.get_bias() + self.assertEqual(weight_keys ^ weight.keys(), set()) + self.assertEqual(bias_keys ^ bias.keys(), set()) + + def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): + r"""Checks that mod has been swapped for an nnqd.LSTM type + module, the bias is float. + """ + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } + self.assertEqual(type(mod), reference_module_type) + for packed_params in mod._all_weight_values: + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) + + def checkLinear(self, mod): + self.assertEqual(type(mod), torch.nn.Linear) + + def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } + self.assertEqual(type(mod), reference_module_type) + if hasattr(mod, "_all_weight_values"): + for packed_params in mod._all_weight_values: + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) + + def checkScriptable(self, orig_mod, calib_data, check_save_load=False): + scripted = torch.jit.script(orig_mod) + self._checkScriptable(orig_mod, scripted, calib_data, check_save_load) + + # Use first calib_data entry as trace input + traced = torch.jit.trace(orig_mod, calib_data[0]) + self._checkScriptable(orig_mod, traced, calib_data, check_save_load) + + # Call this twice: once for a scripted module and once for a traced module + def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load): + self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data) + + # Test save/load + buffer = io.BytesIO() + torch.jit.save(script_mod, buffer) + + buffer.seek(0) + loaded_mod = torch.jit.load(buffer) + # Pending __get_state_ and __set_state__ support + # See tracking task https://github.com/pytorch/pytorch/issues/23984 + if check_save_load: + self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data) + + def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): + for inp in calib_data: + ref_output = orig_mod(*inp) + scripted_output = test_mod(*inp) + self.assertEqual(scripted_output, ref_output) + + def checkGraphModeOp( + self, + module, + inputs, + quantized_op, + tracing=False, + debug=False, + check=True, + eval_mode=True, + dynamic=False, + qconfig=None, + ): + if debug: + print("Testing:", str(module)) + qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} + + if eval_mode: + module = module.eval() + if dynamic: + qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig} + model = get_script_module(module, tracing, inputs[0]).eval() + if debug: + print("input graph:", model.graph) + models = {} + outputs = {} + for debug in [True, False]: + if dynamic: + models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug) + # make sure it runs + outputs[debug] = models[debug](inputs) + else: + # module under test can contain in-place ops, and we depend on + # input data staying constant for comparisons + inputs_copy = copy.deepcopy(inputs) + models[debug] = quantize_jit( + model, + qconfig_dict, + test_only_eval_fn, + [inputs_copy], + inplace=False, + debug=debug, + ) + # make sure it runs + outputs[debug] = models[debug](*inputs[0]) + + if debug: + print("debug graph:", models[True].graph) + print("non debug graph:", models[False].graph) + + if check: + # debug and non-debug option should have the same numerics + self.assertEqual(outputs[True], outputs[False]) + + # non debug graph should produce quantized op + FileCheck().check(quantized_op).run(models[False].graph) + + return models[False] + + def checkGraphModuleNodes( + self, + graph_module, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + ): + """Check if GraphModule contains the target node + Args: + graph_module: the GraphModule instance we want to check + expected_node, expected_node_occurrence, expected_node_list: + see docs for checkGraphModeFxOp + """ + nodes_in_graph = {} + node_list = [] + modules = dict(graph_module.named_modules(remove_duplicate=False)) + for node in graph_module.graph.nodes: + n = None + if node.op == "call_function" or node.op == "call_method": + n = NodeSpec(node.op, node.target) + elif node.op == "call_module": + n = NodeSpec(node.op, type(modules[node.target])) + + if n is not None: + node_list.append(n) + if n in nodes_in_graph: + nodes_in_graph[n] += 1 + else: + nodes_in_graph[n] = 1 + + if expected_node is not None: + self.assertTrue( + expected_node in nodes_in_graph, + "node:" + str(expected_node) + " not found in the graph module", + ) + + if expected_node_occurrence is not None: + for expected_node, occurrence in expected_node_occurrence.items(): + if occurrence != 0: + self.assertTrue( + expected_node in nodes_in_graph, + "Check failed for node:" + str(expected_node) + " not found", + ) + self.assertTrue( + nodes_in_graph[expected_node] == occurrence, + "Check failed for node:" + + str(expected_node) + + " Expected occurrence:" + + str(occurrence) + + " Found occurrence:" + + str(nodes_in_graph[expected_node]), + ) + else: + self.assertTrue( + expected_node not in nodes_in_graph, + "Check failed for node:" + + str(expected_node) + + " expected no occurrence but found", + ) + + if expected_node_list is not None: + cur_index = 0 + for n in node_list: + if cur_index == len(expected_node_list): + return + if n == expected_node_list[cur_index]: + cur_index += 1 + self.assertTrue( + cur_index == len(expected_node_list), + "Check failed for graph:" + + self.printGraphModule(graph_module, print_str=False) + + "Expected ordered list:" + + str(expected_node_list), + ) + + def printGraphModule(self, graph_module, print_str=True): + modules = dict(graph_module.named_modules(remove_duplicate=False)) + node_infos = [] + for n in graph_module.graph.nodes: + node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) + if n.op == "call_module": + node_info += " module type: " + repr(type(modules[n.target])) + node_infos.append(node_info) + str_to_print = "\n".join(node_infos) + if print_str: + print(str_to_print) + return str_to_print + + if HAS_FX: + + def assert_types_for_matched_subgraph_pairs( + self, + matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], + expected_types: dict[ + str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] + ], + gm_a: GraphModule, + gm_b: GraphModule, + ) -> None: + """ + Verifies that the types specified in expected_types match + the underlying objects pointed to by the nodes in matched_subgraph_pairs. + + An example successful test case: + + matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)} + expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)} + + The function tests for key equivalence, and verifies types with + instance checks. + """ + + def _get_underlying_op_type( + node: Node, gm: GraphModule + ) -> Union[Callable, str]: + if node.op == "call_module": + mod = getattr(gm, node.target) + return type(mod) + else: + assert node.op in ("call_function", "call_method") + return node.target + + self.assertTrue( + len(matched_subgraph_pairs) == len(expected_types), + f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", + ) + for k, v in expected_types.items(): + expected_types_a, expected_types_b = v + exp_type_start_a, exp_type_end_a = expected_types_a + exp_type_start_b, exp_type_end_b = expected_types_b + subgraph_a, subgraph_b = matched_subgraph_pairs[k] + + act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a) + act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) + act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) + act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) + types_match = ( + (exp_type_start_a is act_type_start_a) + and (exp_type_end_a is act_type_end_a) + and (exp_type_start_b is act_type_start_b) + and (exp_type_end_b is act_type_end_b) + ) + self.assertTrue( + types_match, + f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " + f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", + ) + + def assert_ns_compare_dict_valid( + self, + act_compare_dict: dict[str, dict[str, dict[str, Any]]], + ) -> None: + """ + Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid: + 1. for each layer, results are recorded for two models + 2. number of seen tensors match + 3. shapes of each pair of seen tensors match + """ + for layer_name, result_type_to_data in act_compare_dict.items(): + for result_type, layer_data in result_type_to_data.items(): + self.assertTrue( + len(layer_data) == 2, + f"Layer {layer_name} does not have exactly two model results.", + ) + model_name_0, model_name_1 = layer_data.keys() + for res_idx in range(len(layer_data[model_name_0])): + layer_data_0 = layer_data[model_name_0][res_idx] + layer_data_1 = layer_data[model_name_1][res_idx] + self.assertTrue( + layer_data_0["type"] == layer_data_0["type"], + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.", + ) + + self.assertTrue( + len(layer_data_0["values"]) == len(layer_data_1["values"]), + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", + ) + + # F.conv1d weight has rank 3, and toq.conv1d unpacked weight + # has rank 4. For now, skip the length check for conv1d only. + is_weight_functional_conv1d = ( + result_type == NSSingleResultValuesType.WEIGHT.value + and ( + "conv1d" in layer_data_0["prev_node_target_type"] + or "conv1d" in layer_data_1["prev_node_target_type"] + ) + ) + if not is_weight_functional_conv1d: + for idx in range(len(layer_data_0["values"])): + values_0 = layer_data_0["values"][idx] + values_1 = layer_data_1["values"][idx] + if isinstance(values_0, torch.Tensor): + self.assertTrue( + values_0.shape == values_1.shape, + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) + elif isinstance(values_0, list): + values_0 = values_0[0] + values_1 = values_1[0] + self.assertTrue( + values_0.shape == values_1.shape, + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) + else: + assert isinstance( + values_0, tuple + ), f"unhandled type {type(values_0)}" + assert len(values_0) == 2 + assert len(values_0[1]) == 2 + assert values_0[0].shape == values_1[0].shape + assert values_0[1][0].shape == values_1[1][0].shape + assert values_0[1][1].shape == values_1[1][1].shape + + # verify that ref_node_name is valid + ref_node_name_0 = layer_data_0["ref_node_name"] + ref_node_name_1 = layer_data_1["ref_node_name"] + prev_node_name_0 = layer_data_0["prev_node_name"] + prev_node_name_1 = layer_data_1["prev_node_name"] + if ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_OUTPUT.value + ): + self.assertTrue(ref_node_name_0 == prev_node_name_0) + self.assertTrue(ref_node_name_1 == prev_node_name_1) + elif ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_INPUT.value + ): + self.assertTrue(ref_node_name_0 != prev_node_name_0) + self.assertTrue(ref_node_name_1 != prev_node_name_1) + + def checkGraphModeFxOp( + self, + model, + inputs, + quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + is_reference=False, + print_debug_info=False, + custom_qconfig_dict=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config=None, + backend_config=None, + ): + """Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurrences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + is_reference: if True, enables reference mode + print_debug_info: if True, prints debug info + custom_qconfig_dict: overrides default qconfig_dict + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare + + Returns: + A dictionary with the following structure: + { + "prepared": ..., # the prepared model + "quantized": ..., # the quantized non-reference model + "quantized_reference": ..., # the quantized reference model + "result": ..., # the result for either quantized or + # quantized_reference model depending on the + # is_reference argument + } + """ + # TODO: make img_data a single example instead of a list + if type(inputs) is list: + inputs = inputs[0] + + if quant_type == QuantType.QAT: + qconfig_mapping = get_default_qat_qconfig_mapping( + torch.backends.quantized.engine + ) + model.train() + elif quant_type == QuantType.STATIC: + qconfig_mapping = get_default_qconfig_mapping( + torch.backends.quantized.engine + ) + model.eval() + else: + qconfig = default_dynamic_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + model.eval() + + if quant_type == QuantType.QAT: + prepare = prepare_qat_fx + else: + prepare = prepare_fx + + # overwrite qconfig_dict with custom_qconfig_dict + if custom_qconfig_dict is not None: + assert type(custom_qconfig_dict) in ( + QConfigMapping, + dict, + ), "custom_qconfig_dict should be a QConfigMapping or a dict" + if isinstance(custom_qconfig_dict, QConfigMapping): + qconfig_mapping = custom_qconfig_dict + else: + qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) + prepared = prepare( + model, + qconfig_mapping, + example_inputs=inputs, + prepare_custom_config=prepare_custom_config, + backend_config=backend_config, + ) + if quant_type != QuantType.DYNAMIC: + prepared(*inputs) + + if print_debug_info: + print() + print("quant type:\n", quant_type) + print("original model:\n", model) + print() + print("prepared model:\n", prepared) + + self.checkGraphModuleNodes( + prepared, + prepare_expected_node, + prepare_expected_node_occurrence, + prepare_expected_node_list, + ) + + prepared_copy = copy.deepcopy(prepared) + qgraph = convert_fx(copy.deepcopy(prepared)) + qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared)) + result = qgraph(*inputs) + result_reference = qgraph_reference(*inputs) + qgraph_copy = copy.deepcopy(qgraph) + qgraph_reference_copy = copy.deepcopy(qgraph_reference) + + qgraph_to_check = qgraph_reference if is_reference else qgraph + if print_debug_info: + print() + print("quantized model:\n", qgraph_to_check) + self.printGraphModule(qgraph_to_check) + print() + self.checkGraphModuleNodes( + qgraph_to_check, + expected_node, + expected_node_occurrence, + expected_node_list, + ) + return { + "prepared": prepared_copy, + "quantized": qgraph_copy, + "quantized_reference": qgraph_reference_copy, + "quantized_output": result, + "quantized_reference_output": result_reference, + } + + def checkEmbeddingSerialization( + self, + qemb, + num_embeddings, + embedding_dim, + indices, + offsets, + set_qconfig, + is_emb_bag, + dtype=torch.quint8, + ): + # Test serialization of dynamic EmbeddingBag module using state_dict + if is_emb_bag: + inputs = [indices, offsets] + else: + inputs = [indices] + emb_dict = qemb.state_dict() + b = io.BytesIO() + torch.save(emb_dict, b) + b.seek(0) + loaded_dict = torch.load(b) + embedding_unpack = torch.ops.quantized.embedding_bag_unpack + # Check unpacked weight values explicitly + for key in emb_dict: + if isinstance(emb_dict[key], torch._C.ScriptObject): + assert isinstance(loaded_dict[key], torch._C.ScriptObject) + emb_weight = embedding_unpack(emb_dict[key]) + loaded_weight = embedding_unpack(loaded_dict[key]) + self.assertEqual(emb_weight, loaded_weight) + + # Check state dict serialization and torch.save APIs + if is_emb_bag: + loaded_qemb = nnq.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + mode="sum", + dtype=dtype, + ) + else: + loaded_qemb = nnq.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype + ) + self.check_eager_serialization(qemb, loaded_qemb, inputs) + + loaded_qemb.load_state_dict(loaded_dict) + self.assertEqual( + embedding_unpack(qemb._packed_params._packed_weight), + embedding_unpack(loaded_qemb._packed_params._packed_weight), + ) + + # Test JIT serialization + self.checkScriptable(qemb, [inputs], check_save_load=True) + + # Test from_float call + if is_emb_bag: + float_embedding = torch.nn.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) + else: + float_embedding = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + + if set_qconfig: + float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 + ) + float_embedding.qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=float_qparams_observer + ) + + prepare_dynamic(float_embedding) + + float_embedding(*inputs) + if is_emb_bag: + q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding) + expected_name = "QuantizedEmbeddingBag" + else: + q_embeddingbag = nnq.Embedding.from_float(float_embedding) + expected_name = "QuantizedEmbedding" + + q_embeddingbag(*inputs) + + self.assertTrue(expected_name in str(q_embeddingbag)) + + +class QuantizationLiteTestCase(QuantizationTestCase): + def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): + # Creates quantized model for testing mobile script modules + qengine = "qnnpack" + with override_quantized_engine(qengine): + # FIXME(rec): shouldn't qconfig be passed to quantize? + qconfig = torch.ao.quantization.get_default_qconfig(qengine) # noqa: F841 + model = model_class(**kwargs) + model = quantize(model, test_only_eval_fn, [self.calib_data]) + + return model + + def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor): + # Compares the numerical outputs for script and lite modules + qengine = "qnnpack" + with override_quantized_engine(qengine): + script_module = torch.jit.script(model) + script_module_result = script_module(input) + + max_retry = 5 + for retry in range(1, max_retry + 1): + # retries `max_retry` times; breaks iff succeeds else throws exception + try: + buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter() + ) + buffer.seek(0) + mobile_module = _load_for_lite_interpreter(buffer) + + mobile_module_result = mobile_module(input) + + torch.testing.assert_close( + script_module_result, mobile_module_result + ) + mobile_module_forward_result = mobile_module.forward(input) + torch.testing.assert_close( + script_module_result, mobile_module_forward_result + ) + + mobile_module_run_method_result = mobile_module.run_method( + "forward", input + ) + torch.testing.assert_close( + script_module_result, mobile_module_run_method_result + ) + except AssertionError as e: + if retry == max_retry: + raise e + else: + continue + break + + +class PT2EQuantizationTestCase(QuantizationTestCase): + """ + Base QuantizationTestCase for PT2 with some helper methods. + """ + + _MAP_TO_FX_TRACED_OPS = { + torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + } + + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + check_against_fx_quant=False, + fx_qconfig_mapping=None, + export_with_dynamic_shape=False, + is_qat=False, + is_debug_mode=False, + training_ir_node_occurrence=None, + ): + # resetting dynamo cache + torch._dynamo.reset() + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + dynamic_shapes = tuple( + {0: torch.export.Dim("dim")} if i == 0 else None + for i in range(len(example_inputs)) + ) + m = export( + m, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, + ).module() + + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + if is_debug_mode: + print("prepared model:", m) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + if is_debug_mode: + print("quantized model", m) + + pt2_quant_output = m(*example_inputs) + ns = NodeSpec + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + if check_against_fx_quant: + qconfig_mapping = fx_qconfig_mapping + backend_config = get_executorch_backend_config() + m_copy = copy.deepcopy(m_eager) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + m_fx(*example_inputs) + m_fx = _convert_to_reference_decomposed_fx( + m_fx, backend_config=backend_config + ) + m_fx = export( + m_fx, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, + ).module() + node_occurrence = {} + for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): + if k in expected_node_occurrence: + node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if training_ir_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() + } + self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) + fx_quant_output = m_fx(*example_inputs) + self.assertEqual(fx_quant_output, pt2_quant_output) + return m + + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + m = export(m, example_inputs, strict=True).module() + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + return m + + def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + return self._quantize(m, quantizer, example_inputs) + + +# Below are a series of toy models to use in testing quantization + + +class SingleLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class AnnotatedSingleLayerLinearModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class SingleLayerLinearDynamicModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearAddModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = torch.add(x, 5) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, x): + x = self.mod(x) + return x + + +class RNNCellDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRUCell": + self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) + if mod_type == "LSTMCell": + self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) + if mod_type == "RNNReLU": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) + if mod_type == "RNNTanh": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) + + def forward(self, x): + x = self.mod(x) + return x + + +class LSTMwithHiddenDynamicModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, x, hid): + x, hid = self.lstm(x, hid) + return x, hid + + +class ConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvTransposeModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvTransposeModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvBnModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvBnModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.qconfig = default_qconfig + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.dequant(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class ConvBnReLUModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class AnnotatedConvBnReLUModel(torch.nn.Module): + def __init__(self, qengine="fbgemm"): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=True) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.dequant(x) + return x + + def fuse_model(self): + # TODO: remove this check and define two fuse_modules function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat( + self, [["conv", "bn", "relu"]], inplace=True + ) + else: + torch.ao.quantization.fuse_modules( + self, [["conv", "bn", "relu"]], inplace=True + ) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class TwoLayerConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class TwoLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearModelWithSubmodule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.subm = TwoLayerLinearModel() + self.fc = nn.Linear(5, 5) + + def forward(self, x): + x = self.subm(x) + x = self.fc(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.subm.get_example_inputs() + + +class AnnotatedTwoLayerLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float)) + self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class ActivationsTestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") + self.quant = torch.ao.quantization.QuantStub() + self.hardswish = torch.nn.Hardswish().to(dtype=torch.float) + self.elu = torch.nn.ELU().to(dtype=torch.float) + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.hardswish(x) + x = self.elu(x) + x = self.dequant(x) + return x + + +class LinearReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearReluLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearReluAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = torch.add(x, 5) + x = self.fc2(x) + self.relu = torch.nn.ReLU() + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearBnLeakyReluModel(torch.nn.Module): + def __init__(self, with_bn=True): + super().__init__() + self.linear = nn.Linear(5, 5) + self.bn1d = nn.BatchNorm1d(5) + self.leaky_relu = nn.LeakyReLU(0.01) + self.with_bn = with_bn + + def forward(self, x): + x = self.linear(x) + if self.with_bn: + x = self.bn1d(x) + x = self.leaky_relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class LinearTanhModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(5, 5) + self.tanh = nn.Tanh() + + def forward(self, x): + x = self.linear(x) + x = self.tanh(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class ConvBnAddReluModel(torch.nn.Module): + def __init__( + self, + with_bn=True, + with_relu=True, + left_conv=True, + two_conv=True, + use_torch_add=True, + ): + super().__init__() + self.conv = nn.Conv2d(5, 5, (2, 2)) + self.conv2 = nn.Conv2d(5, 5, (2, 2)) + self.bn = nn.BatchNorm2d(5) + self.relu = nn.ReLU() + self.with_bn = with_bn + self.with_relu = with_relu + self.two_conv = two_conv + self.left_conv = left_conv + self.use_torch_add = use_torch_add + + def forward(self, x1, x2): + if self.two_conv: + if self.use_torch_add: + if self.with_bn: + x = torch.add(self.bn(self.conv(x1)), self.conv2(x1)) + else: + x = torch.add(self.conv(x1), self.conv2(x1)) + else: + if self.with_bn: + x = self.bn(self.conv(x1)) + self.conv2(x1) + else: + x = self.conv(x1) + self.conv2(x1) + else: + if self.use_torch_add: + if self.left_conv: + if self.with_bn: + x = torch.add(self.bn(self.conv(x1)), x2) + else: + x = torch.add(self.conv(x1), x2) + else: + if self.with_bn: + x = torch.add(x2, self.bn(self.conv(x1))) + else: + x = torch.add(x2, self.conv(x1)) + else: + if self.left_conv: + if self.with_bn: + x = self.bn(self.conv(x1)) + x2 + else: + x = self.conv(x1) + x2 + else: + if self.with_bn: + x = x2 + self.bn(self.conv(x1)) + else: + x = x2 + self.conv(x1) + if self.with_relu: + x = self.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) + + +# TODO: self.fc should be self.conv +class ConvReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +# TODO: self.fc should be self.conv +class ConvReluConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +# TODO: self.fc should be self.conv +class ConvReluAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = torch.add(x, 5) + x = self.fc2(x) + self.relu = torch.nn.ReLU() + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class NormalizationTestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.layer_norm = torch.nn.LayerNorm(8) + self.group_norm = torch.nn.GroupNorm(2, 8) + self.instance_norm1d = torch.nn.InstanceNorm1d(8) + self.instance_norm2d = torch.nn.InstanceNorm2d(8) + self.instance_norm3d = torch.nn.InstanceNorm3d(8) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.layer_norm(x) + x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3)) + x = self.instance_norm1d(x) + x = self.instance_norm2d(x.unsqueeze(-1)) + x = self.instance_norm3d(x.unsqueeze(-1)) + return x + + +class NestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedNestedModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.fc1 = QuantWrapper(self.sub2.fc1) + if qengine == "fbgemm": + self.sub2.fc1.qconfig = default_per_channel_qconfig + else: + self.sub2.fc1.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedSubNestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = QuantWrapper(TwoLayerLinearModel()) + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class AnnotatedCustomConfigNestedModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = TwoLayerLinearModel() + self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) + self.fc3.qconfig = default_qconfig + self.sub2.qconfig = default_qconfig + + custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} + custom_qconfig = QConfig( + activation=default_observer.with_args(**custom_options), + weight=default_weight_observer, + ) + self.sub2.fc1.qconfig = custom_qconfig + + self.sub2.fc1 = QuantWrapper(self.sub2.fc1) + self.sub2.fc2 = QuantWrapper(self.sub2.fc2) + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class QuantSubModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sub1 = LinearReluModel() + self.sub2 = QuantWrapper(TwoLayerLinearModel()) + self.sub2.qconfig = default_qconfig + self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.fc3.qconfig = default_qconfig + + def forward(self, x): + x = self.sub1(x) + x = self.sub2(x) + x = self.fc3(x) + return x + + +class InnerModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) + self.relu1 = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + def fuse_modules(self): + fusable_layers = [] + named_children = list(self.named_children()) + for idx, (current_name, layer) in enumerate(named_children): + if isinstance(layer, torch.nn.Linear): + if idx >= len(named_children) - 1: + break + if isinstance(named_children[idx + 1][1], torch.nn.ReLU): + fusable_layers.append([current_name, named_children[idx + 1][0]]) + # TODO: remove this check and define two fuse_modules function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) + else: + torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) + + +class FunctionalLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.rand((5, 5)) + self.bias = torch.zeros(5) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 5),) + + +class SingleLayerFunctionalLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class TwoLayerFunctionalLinearModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalLinearAddModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = torch.add(x, 5) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalLinearReluModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = FunctionalLinear() + + def forward(self, x): + x = self.linear(x) + x = F.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear.get_example_inputs() + + +class FunctionalLinearReluLinearModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = FunctionalLinear() + self.relu = nn.ReLU() + self.linear2 = FunctionalLinear() + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.linear1.get_example_inputs() + + +class FunctionalConv2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.rand(3, 3, 3, 3) + self.bias = torch.rand(3) + self.stride = (1, 1) + self.padding = (0, 0) + self.dilation = (1, 1) + self.groups = 1 + + def forward(self, x): + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def get_example_inputs(self) -> tuple[Any, ...]: + return (torch.rand(1, 3, 5, 5),) + + +class SingleLayerFunctionalConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class TwoLayerFunctionalConvModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + self.conv2 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class FunctionalConvReluModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = FunctionalConv2d() + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv.get_example_inputs() + + +class FunctionalConvReluConvModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = FunctionalConv2d() + self.relu = nn.ReLU() + self.conv2 = FunctionalConv2d() + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + return x + + def get_example_inputs(self) -> tuple[Any, ...]: + return self.conv1.get_example_inputs() + + +class SkipQuantModel(torch.nn.Module): + r"""We can skip quantization by explicitly + setting qconfig of a submodule to None + """ + + def __init__(self) -> None: + super().__init__() + self.sub = InnerModule() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + return self.fc(self.sub(x)) + + def fuse_modules(self): + self.sub.fuse_modules() + + +class AnnotatedSkipQuantModel(torch.nn.Module): + r"""We can skip quantization by explicitly + setting qconfig of a submodule to None + """ + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + self.sub = QuantWrapper(InnerModule()) + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + # don't quantize this fc + self.fc.qconfig = None + + def forward(self, x): + return self.fc(self.sub(x)) + + def fuse_modules(self): + self.sub.module.fuse_modules() + + +class QuantStubModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self) -> None: + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.fc(x) + return self.dequant(x) + + +class ManualLinearQATModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.fc2(x) + return self.dequant(x) + + +class ManualDropoutQATModel(torch.nn.Module): + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.dropout(x) + return self.dequant(x) + + +class ManualLinearDynamicQATModel(torch.nn.Module): + r"""A Module that uses a dynamic QAT by default.""" + + def __init__(self, qconfig=None): + super().__init__() + self.qconfig = qconfig or default_dynamic_qat_qconfig + self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +class ManualConvLinearQATModel(torch.nn.Module): + r"""A module with manually inserted `QuantStub` and `DeQuantStub` + and contains both linear and conv modules + """ + + def __init__(self, qconfig=None): + super().__init__() + self.qconfig = ( + qconfig + if qconfig + else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + ) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) + self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float) + self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float) + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = x.view(-1, 64).contiguous() + x = self.fc1(x) + x = self.fc2(x) + return self.dequant(x) + + +class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): + r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. + Supported only with qnnpack. + """ + + def __init__(self) -> None: + super().__init__(default_symmetric_qnnpack_qat_qconfig) + + +class ManualEmbeddingBagLinear(nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") + self.emb.qconfig = default_embedding_qat_qconfig + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.linear = nn.Linear(12, 1).to(dtype=torch.float) + self.qconfig = get_default_qat_qconfig("qnnpack") + + def forward( + self, + input: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + per_sample_weights: Optional[torch.Tensor] = None, + ): + x = self.emb(input, offsets, per_sample_weights) + x = self.quant(x) + x = self.linear(x) + return self.dequant(x) + + +class DeFusedEmbeddingBagLinear(nn.Module): + r"""A module to simulate QAT embedding bag with a linear layer, + this module uses a separate embedding and bagging op, similar + to that which is described in the EmbeddingBag documentation. + + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html + """ + + def __init__(self) -> None: + super().__init__() + self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) + self.emb.qconfig = default_embedding_qat_qconfig + self.bagging_op = torch.sum + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.linear = nn.Linear(12, 1).to(dtype=torch.float) + self.qconfig = get_default_qat_qconfig("qnnpack") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = self.bagging_op(self.emb(input), dim=1) + x = self.quant(x) + x = self.linear(x) + return self.dequant(x) + + +class SubModelForFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) + self.bn = nn.BatchNorm2d(2).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class SubModelWithoutFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) + self.relu = nn.ReLU(inplace=False).to(dtype=torch.float) + + def forward(self, x): + return self.relu(self.conv(x)) + + +class ModelForFusion(nn.Module): + def __init__(self, qconfig): + super().__init__() + self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float) + self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.sub1 = SubModelForFusion() + self.sub2 = SubModelWithoutFusion() + self.fc = nn.Linear(36, 10).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.qconfig = qconfig + self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float) + self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float) + self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float) + self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float) + self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float) + # don't quantize sub2 + self.sub2.qconfig = None + self.fc.qconfig = None + + def forward(self, x): + x = x.squeeze(2) + x = self.quant(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu4(x) + x = x.unsqueeze(2) + y = x.unsqueeze(2) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.sub1(x) + x = self.dequant(x) + x = self.sub2(x) + x = x.reshape(-1, 36).contiguous() + x = self.fc(x) + y = self.conv2(y) + y = self.relu2(y) + y = self.bn2(y) + y = self.relu3(y) + y = self.dequant(y) + return x + + +class ConvBNReLU(nn.Sequential): + def __init__(self) -> None: + super().__init__( + nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) + ) + + +class ModelWithSequentialFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 1) + self.relu1 = nn.ReLU(inplace=False) + layers = [ConvBNReLU() for _ in range(3)] + self.features = nn.Sequential(*layers) + head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] + self.classifier = nn.Sequential(*head) + self.seq = nn.Sequential() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv1(x) + x = self.relu1(x) + x = self.features(x) + x = torch.reshape(x, (-1, 3 * 10 * 10)) + x = self.classifier(x) + x = self.seq(x) + x = self.dequant(x) + return x + + +class ModelForFusionWithBias(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float) + self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) + self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float) + self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.dequant(x) + return x + + +class ModelForLinearBNFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 10) + self.bn = nn.BatchNorm1d(10) + nn.init.uniform_(self.bn.weight) + nn.init.uniform_(self.bn.bias) + + def forward(self, x): + return self.bn(self.fc(x)) + + +class DummyObserver(torch.nn.Module): + def calculate_qparams(self): + return 1.0, 0 + + def forward(self, x): + return x + + +class ModelForConvTransposeBNFusion(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.ConvTranspose1d(3, 3, 1) + self.bn1 = nn.BatchNorm1d(3) + self.conv2 = nn.ConvTranspose2d(3, 3, 1) + self.bn2 = nn.BatchNorm2d(3) + self.conv3 = nn.ConvTranspose3d(3, 3, 1) + self.bn3 = nn.BatchNorm3d(3) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = x.unsqueeze(2) + x = self.conv2(x) + x = self.bn2(x) + x = x.unsqueeze(2) + x = self.conv3(x) + x = self.bn3(x) + return x + + +class ModelWithFunctionals(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.mycat = nnq.FloatFunctional() + self.myadd = nnq.FloatFunctional() + self.myadd_relu = nnq.FloatFunctional() + self.mymatmul = nnq.FloatFunctional() + # Tracing doesn't work yet for c10 ops with scalar inputs + # https://github.com/pytorch/pytorch/issues/27097 + # self.my_scalar_add = nnq.FloatFunctional() + # self.my_scalar_mul = nnq.FloatFunctional() + + def forward(self, x): + y = self.mycat.cat([x, x, x]) + z = self.myadd.add(y, y) + w = self.myadd_relu.add_relu(z, z) + u = self.mymatmul.matmul(w, w.T) + # Tracing doesn't work yet for c10 ops with scalar inputs + # https://github.com/pytorch/pytorch/issues/27097 + # w = self.my_scalar_add.add_scalar(w, -0.5) + # w = self.my_scalar_mul.mul_scalar(w, 0.5) + return u + + +class ResNetBase(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = torch.nn.Identity() + self.myop = nn.quantized.FloatFunctional() + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(inplanes, 1) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + identity = self.downsample(x) + out = self.myop.add(out, identity) + out = self.relu2(out) + out = self.avgpool(out) + out = torch.flatten(out, 1) + out = self.fc(out) + return out + + def fuse_model(self): + # TODO: remove this check and define two fuse_model function on this module + if self.training: + torch.ao.quantization.fuse_modules_qat( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + else: + torch.ao.quantization.fuse_modules( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + + +class ModelMultipleOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = torch.nn.Identity() + self.skip_add = nn.quantized.FloatFunctional() + self.cat = nn.quantized.FloatFunctional() + self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) + self.fc = nn.Linear(12, 6) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + identity = self.downsample(x) + out = self.skip_add.add(out, identity) + out = self.relu2(out) + out = self.avgpool(out) + out = self.conv2(out) + out = torch.nn.functional.max_pool2d(out, 2, 2) + out = self.cat.cat([out, out]) + out = out.reshape(-1, 3 * 2 * 2) + out = self.fc(out) + return out + + +# Model to ensure consistency of fake quant with true quant +# Average pooling and mean operations are not modelled +# accurately with fake-quant so this model does not +# contain those operations +class ModelMultipleOpsNoAvgPool(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.skip_add = nn.quantized.FloatFunctional() + self.cat = nn.quantized.FloatFunctional() + self.maxpool = nn.MaxPool2d((4, 4)) + self.fc = nn.Linear(12, 6) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + skip = self.conv2(x) + out = self.skip_add.add(out, skip) + out = self.relu2(out) + out = self.maxpool(out) + out = self.conv2(out) + out = torch.nn.functional.max_pool2d(out, 2, 2) + out = self.cat.cat([out, out]) + out = out.reshape(-1, 3 * 2 * 2) + out = self.fc(out) + return out + + +class EmbeddingBagModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) + + def forward(self, indices, offsets, per_sample_weights): + return self.emb(indices, offsets, per_sample_weights) + + +class EmbeddingModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + +class EmbeddingWithStaticLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12) + self.fc = torch.nn.Linear(4, 2) + self.emb.qconfig = float_qparams_weight_only_qconfig + self.qconfig = default_qconfig + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, indices, offsets, linear_in): + emb = self.emb(indices, offsets) + q_x = self.quant(linear_in) + fc = self.fc(q_x) + fc = self.dequant(fc) + features = torch.cat([fc] + [emb], dim=1) + return features + + +class DenseTopMLP(nn.Module): + def __init__( + self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out + ) -> None: + super().__init__() + + self.dense_mlp = nn.Sequential( + nn.Linear(dense_dim, dense_out), + ) + self.top_mlp = nn.Sequential( + nn.Linear(dense_out + embedding_dim, top_out_in), + nn.Linear(top_out_in, top_out_out), + ) + + def forward( + self, + sparse_feature: torch.Tensor, + dense: torch.Tensor, + ) -> torch.Tensor: + dense_feature = self.dense_mlp(dense) + features = torch.cat([dense_feature] + [sparse_feature], dim=1) + + out = self.top_mlp(features) + return out + + +# thin wrapper around embedding bag, because tracing inside nn.Embedding +# bag is not supported at the moment and this is top level +class EmbBagWrapper(nn.Module): + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") + + def forward(self, indices, offsets): + return self.emb_bag(indices, offsets) + + +class SparseNNModel(nn.Module): + _NUM_EMBEDDINGS = 10 + _EMBEDDING_DIM = 5 + _DENSE_DIM = 4 + _DENSE_OUTPUT = 2 + _TOP_OUT_IN = 2 + _TOP_OUT_OUT = 2 + _TOP_MLP_DIM = 1 + + def __init__(self) -> None: + super().__init__() + + self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) + self.dense_top = DenseTopMLP( + self._DENSE_DIM, + self._DENSE_OUTPUT, + self._EMBEDDING_DIM, + self._TOP_OUT_IN, + self._TOP_OUT_OUT, + ) + + def forward( + self, + sparse_indices: torch.Tensor, + sparse_offsets: torch.Tensor, + dense: torch.Tensor, + ) -> torch.Tensor: + sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) + out = self.dense_top(sparse_feature, dense) + + return out + + +class TestHelperModules: + class ControlFlow(torch.nn.Module): + def forward( + self, + xs: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + def true_nested(y: torch.Tensor) -> torch.Tensor: + y = y + y + y = torch.mm(y, y) + return y + + def false_nested(y: torch.Tensor) -> torch.Tensor: + return torch.mm(y, y) + + def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor: + z = control_flow.cond(pred2, true_nested, false_nested, [x]) + return x + z + + def false_fn(x: torch.Tensor, _) -> torch.Tensor: + return x.cos() + + def map_fn( + x: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + x = x.cos() + y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) + x = x + y + return x.sin() + + y = torch.mm(y, y) + return control_flow.map(map_fn, xs, pred1, pred2, y) + + def example_inputs(self): + return ( + torch.ones(2, 2), + torch.tensor([False]), + torch.tensor([False]), + torch.ones(2, 2), + ) + + class Conv2dPropAnnotaton(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv(x) + x = x.view(-1, 3) + x = torch.nn.functional.hardtanh(x, -0.5, 0.5) + x = self.linear(x) + return x + + class Conv2dWithObsSharingOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.hardtanh = torch.nn.Hardtanh() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + x = self.hardtanh(x) + x = torch.mean(x) + return x + + class Conv2dWithTwoLinearPermute(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(16, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear2(self.linear1(permute_out)) + + class Conv2dWithTwoLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linear1 = torch.nn.Linear(64, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + reshape_out = torch.reshape(conv_out, (2, 64)) + return self.linear2(self.linear1(reshape_out)) + + class ConvLinearWPermute(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3) + self.linear1 = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + return self.linear1(permute_out) + + class TwoLinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(8, 16, bias=False) + self.linear2 = torch.nn.Linear(16, 8) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + def example_inputs(self): + return (torch.randn(2, 8),) + + class ConvMaxPool2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(2, 2, 1) + self.pool = torch.nn.MaxPool2d(1, 1) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + return x + + class ConvWithAdaptiveAvgPool2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + return x + + class ConvWithBNRelu(torch.nn.Module): + def __init__(self, relu, dim=2, bn=True, bias=True, padding=0): + super().__init__() + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} + bns = { + 1: torch.nn.BatchNorm1d, + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d, + } + self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding) + + if bn: + self.bn = bns[dim](3) + else: + self.bn = torch.nn.Identity() + if relu: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + class ConvTWithBNRelu(torch.nn.Module): + def __init__(self, relu, dim=2, bn=True, bias=True): + super().__init__() + convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d} + bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d} + self.convt = convts[dim](3, 3, 3, bias=bias) + + if bn: + self.bn = bns[dim](3) + else: + self.bn = torch.nn.Identity() + if relu: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x = self.convt(x) + x = self.bn(x) + return self.relu(x) + + class Conv2dThenConv1d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1d = torch.nn.Conv1d(3, 3, 3) + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv2d(x) + x = x.squeeze(0) + x = self.conv1d(x) + return x + + def example_inputs(self): + return (torch.randn(1, 3, 5, 5),) + + class Conv2dWithCat(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x, y): + x = self.conv1(x) + y = self.conv2(y) + z = torch.cat([x, y], dim=1) + return z + + class Conv2dWithTwoCat(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x1, x2, x3, x4): + x1 = self.conv1(x1) + x2 = self.conv2(x2) + y = torch.cat([x1, x2], dim=1) + z = x3 + x4 + w = torch.cat([z, y]) + return w + + class Conv2dWithSplit(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv1(x) + # use split so we get a list of Tensors + x1, x2 = torch.split(x, 2, dim=1) + y = torch.cat([x1, x2], dim=1) + return y + + def example_inputs(self): + return (torch.randn(1, 3, 16, 16),) + + class ThreeAdd(torch.nn.Module): + def forward(self, x1, x2, x3, x4): + y = x1 + x2 + z = x3 + x4 + w = y + z + return w + + class EmbeddingModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + class EmbeddingConvLinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + self.conv = torch.nn.Conv2d(8, 16, (1, 3)) + self.linear = torch.nn.Linear(16, 8) + + def forward(self, indices): + embeddings = self.emb(indices) + embeddings = torch.unsqueeze(embeddings, dim=0) + embeddings = torch.permute(embeddings, (0, 3, 1, 2)) + conv_out = self.conv(embeddings) + conv_out = torch.permute(conv_out, (0, 2, 3, 1)) + conv_out = torch.squeeze(conv_out, dim=0) + return self.linear(conv_out) + + class AddInplaceAdd(torch.nn.Module): + def forward(self, x, y): + x = x + y + x += y + return x + + class MulInplaceMul(torch.nn.Module): + def forward(self, x, y): + x = x * y + x *= y + return x + + class AddMulScalar(torch.nn.Module): + def forward(self, x): + x = x + 3 + x = x * 3 + x += 3 + x *= 3 + return x + + class ConvBnReLU2dAndLinearReLU(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True) + self.linear = torch.nn.Linear(3, 8, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv_bn_relu(x) + permute_out = torch.permute(x, (0, 2, 3, 1)) + linear_out = self.linear(permute_out) + return linear_out + + class GroupwiseConv2d(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(4, 4, 3, groups=2) + + def forward(self, x): + return self.conv(x) + + def example_inputs(self): + return (torch.randn(2, 4, 10, 10),) + + class LinearReluModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc(x)) + return x + + +def _generate_qdq_quantized_model( + mod, inputs, is_qat=False, is_dynamic=False, quantizer=None +): + def get_default_quantizer(is_qat, is_dynamic, inputs): + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) + if has_xpu: + quantizer = XPUInductorQuantizer() + assert (not is_qat) and ( + not is_dynamic + ), "QAT and dynamic quantization is not supported at XPU backend currently" + quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) + else: + quantizer = X86InductorQuantizer() + quantizer.set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) + return quantizer + + maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() + with maybe_no_grad: + export_model = export(mod, inputs, strict=True).module(check_guards=False) + quantizer = ( + quantizer + if quantizer + else get_default_quantizer(is_qat, is_dynamic, inputs) + ) + prepare_model = ( + prepare_qat_pt2e(export_model, quantizer) + if is_qat + else prepare_pt2e(export_model, quantizer) + ) + prepare_model(*inputs) + torch.ao.quantization.move_exported_model_to_eval(prepare_model) + convert_model = convert_pt2e(prepare_model) + return convert_model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py new file mode 100644 index 0000000000000000000000000000000000000000..773bea63eef82f7dd83034d764a484a6085ed3ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py @@ -0,0 +1,608 @@ +# mypy: ignore-errors + +import torch +from torch import Tensor +import itertools + +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from torch.utils import _pytree as pytree +from functools import partial +from torch.utils._mode_utils import no_dispatch, all_same_mode +import torch.autograd.forward_ad as fwAD +from collections.abc import Callable +import re + + +def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor): + elem = wrapper_tensor.elem + metadata_wrapper_tensor = metadata_accessor(wrapper_tensor) + metadata_elem = metadata_accessor(elem) + if metadata_wrapper_tensor == metadata_elem: + return + raise RuntimeError( + f"This operator is not Composite Compliant: the " + f"{metadata_name} of the tensor was modified directly without " + f"going through the PyTorch dispatcher.") + +def check_metadata_consistency(wrapper_tensor, CCT): + # CCT: CompositeCompliantTensor class which is generated using generate_cct + if not isinstance(wrapper_tensor, CCT): + return + things_to_check = { + 'shape': Tensor.size, + 'dtype': lambda x: x.dtype, + 'device': lambda x: x.device, + 'numel': Tensor.numel, + 'stride': Tensor.stride, + 'storage_offset': Tensor.storage_offset, + } + for metadata_name, metadata_accessor in things_to_check.items(): + check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor) + +def is_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided', + 'detach', + 'diagonal', + 'expand', + 'expand_as', + 'movedim', + 'narrow', + 'permute', + 'select', + 'squeeze', + 'transpose', + 't', + 'real', + 'imag', + 'view_as_real', + 'view_as_complex', + 'unflatten', + 'unfold', + 'unsqueeze', + 'view', + 'view_as', + 'unbind', + 'split', + 'split_with_sizes', + 'vsplit', + 'hsplit', + 'tensor_split', + 'chunk', + 'swapaxes', + 'slice', + '_reshape_alias', + '_unsafe_view', + '_conj', + 'alias', + } + +# manually populated from native_functions that have inplace_view: True. +# In the future we will probably be able to grab that list directly +def is_inplace_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided_', + 'detach_', + 'squeeze_', + 'swapaxes_', + 'swapdims_', + 't_', + 'transpose_', + 'unsqueeze_', + } + + +# Introspection please save us +def is_inplace(func): + name = func.overloadpacket.__name__ + if re.match('__i.+__', name): + return True + if re.match('__.+__', name): + return False + return name[-1] == '_' + + +def generate_cct_and_mode(autograd_view_consistency=True): + # This function returns a new class CompositeCompliantTensor + # The two arguments control the behaviour described below. + + # autograd_view_consistency: + # If True, alias result using `set_` if func returns a view + # (See Note [Alias Result]). + # Since Forward AD doesn't work with `set_` + # we disable it by setting alias to False. + + class CompositeCompliantTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, mode, *args, **kwargs): + assert type(elem) is not cls, \ + "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported" + + # The storage of CompositeCompliantTensor should never be used directly + # by a Composite operation; if the Composite + # operator attempts to read from the storage without dispatching then it'll + # raise a RuntimeError due to it being a meta storage. + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=elem.requires_grad, + strides=elem.stride(), storage_offset=elem.storage_offset()) + + if elem.requires_grad: + # CompositeCompliantTensor steals the "requires_grad"-ness. + # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests... + tmp = torch.empty( + (), + dtype=elem.dtype, + device=elem.device, + layout=elem.layout, + requires_grad=False, + ) + # Use set_ rather than empty_strided() + copy_ so that we can preserve + # things like storage_offset. + tmp.set_( + source=elem.untyped_storage().clone(), + storage_offset=elem.storage_offset(), + size=elem.size(), + stride=elem.stride(), + ) + r.elem = tmp + else: + r.elem = elem + + assert r.stride() == r.elem.stride() + + # Propagate conjugate bits to the wrapper tensor + # Ref: https://github.com/albanD/subclass_zoo/issues/24 + # Ref: https://github.com/albanD/subclass_zoo/issues/21 + torch._C._set_conj(r, r.elem.is_conj()) + torch._C._set_neg(r, r.elem.is_neg()) + + r.mode = mode + return r + + def __repr__(self): + return f"CompositeCompliantTensor({self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + all_args = pytree.arg_tree_leaves(*args, **(kwargs or {})) + modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor)) + if not all_same_mode(modes): + raise RuntimeError("Multiple CompositeCompliantTensorModes NYI") + with modes[0]: + return func(*args, **kwargs) + + class CompositeCompliantTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, CompositeCompliantTensor) else e + + def wrap(e): + return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e + + if func is torch.ops.aten._local_scalar_dense.default: + raise RuntimeError( + ".item() is not allowed to be called inside of composite " + "functions in the PyTorch library because not all backends " + "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.") + + if func.overloadpacket.__name__ in ('set_', 'resize_'): + raise RuntimeError( + f"{func.__name__} is not allowed to be called inside of " + f"Composite operators.") + + if is_inplace(func): + # NB: We are making an assumption that if the function is in-place, + # then the first argument is being written to. Introspection please save us! + mutated_argument = args[0] + if not isinstance(mutated_argument, CompositeCompliantTensor) and \ + any(isinstance(a, CompositeCompliantTensor) for a in args[1:]): + raise RuntimeError( + 'Not composite compliant: performing in-place operation ' + f'{func.__name__} where the Tensor being written to is ' + 'regular Tensor but the other tensors are Tensor Subclasses. ' + 'Please try to avoid this in-place operation.') + + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs) + rs = tree_map(wrap, unwrapped_rs) + + if is_view_fn(func) and autograd_view_consistency: + # Note [Alias Result] + # Autograd asserts that for B = A.view_fn(...), B and A's storages + # are the same. Here we try to make B alias A to avoid those asserts. + # See https://github.com/pytorch/pytorch/issues/65339 for more information + # about the issue. + with no_dispatch(): + # Idea: this is a weird way of getting a storage that aliases the input. + # This is a workaround for #65339. + # 1. under no_dispatch, all of the wrapper tensors look like regular + # tensors with special storage (the storage is nullptr and + # advertises CPU/CUDA device. + # 2. we run func, which ends up running the view operation + # 3. All view operations reuse the input's storage and return + # result Tensor(s) with new sizes/strides/offset that alias + # the input. + # 4. we set the storage (and sizes/strides/offset) of the wrapper + # tensor results to be that of the tensors that alias the input + result = func(*args, **kwargs) + if isinstance(result, (tuple, list)): + for a, b in zip(rs, result, strict=True): + a.set_(b) + else: + rs.set_(result) + + # Some operations are allowed to in-place modify the metadata of the + # inputs. The only ones are the "inplace view functions"; when we + # run into these, we manually modify the metadata of the input. + with no_dispatch(): + if is_inplace_view_fn(func): + func(*args, **kwargs) + + # For each CompositeCompliantTensor t, we check that t and t.elem + # have consistent metadata. If they don't have consistent metadata, + # that means the operator did something fishy. + check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor) + pytree.tree_map_(check, args) + pytree.tree_map_(check, kwargs) + pytree.tree_map_(check, rs) + return rs + + return CompositeCompliantTensor, CompositeCompliantTensorMode() + +def is_tensorlist(lst): + if not isinstance(lst, list) and not isinstance(lst, tuple): + return False + if len(lst) == 0: + return False + all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst) + if all_tensors: + return True + exists_one_tensor = all(isinstance(elt, torch.Tensor) for elt in lst) + if exists_one_tensor: + raise RuntimeError('This test assumes that PyTorch APIs cannot take ' + 'mixed lists of Tensor and other things') + return False + + +def maybe_map(fn, should_map, arg): + return fn(arg) if should_map else arg + + +def wrap(arg, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + if isinstance(arg, torch.Tensor): + return CCT(arg, cct_mode) + if is_tensorlist(arg): + return [CCT(a, cct_mode) for a in arg] + raise RuntimeError("wrap assumes that the input can be wrapped") + + +# Given a list of flat arguments, some of which may be Tensors, return all +# possible ways some of the arguments could be CompositeCompliantTensors (CCT). +# For example, given Tensors A, B, C and flat_args = [A, 1, B], +# We would return the following 4 options: +# [CCT(A), 1, CCT(B)] +# [CCT(A), 1, B] +# [A, 1, CCT(B)] +# [A, 1, B] +# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops +# don't accept that many input Tensors. +def generate_subclass_choices(flat_args, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args] + subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes] + + for which_args_are_wrapped in itertools.product(*subclass_options): + + result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg) + for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args, strict=True)] + yield result, which_args_are_wrapped + + +# For an operation f(*args, **kwargs), each Tensor argument may either be +# a regular Tensor or a Tensor Subclass. This iterator iterates through +# all of those options. +def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + flat_kwargs, spec = tree_flatten(kwargs) + flat_args_kwargs = list(args) + list(flat_kwargs) + for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode): + new_args = choice[:len(args)] + new_kwargs = tree_unflatten(choice[len(args):], spec) + which_args_are_wrapped = debug_metadata[:len(args)] + which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec) + yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped + + +def raise_composite_compliance_error(err, additional_info=''): + raise RuntimeError( + "Composite compliance check failed with " + "the above error.\n" + f"{additional_info}" + "If you are adding an OpInfo of an " + "existing operator, please feel free to skip this test " + "because the problem was pre-existing and file an issue. " + "Otherwise, if you added a new operator, please read " + "through the Composite Compliance section in " + "aten/src/ATen/native/README.md for how to resolve this. " + ) from err + + +# This test checks ALL possible permutations of calling `op` with arguments +# that are individually either a regular Tensor or a Tensor subclass. +# +# The general strategy is to wrap some Tensor args and kwargs in +# CompositeCompliantTensor wrappers and call the operation. + +# If some composite operation does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_all_permutations(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + expected = op(*args, **kwargs) + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + try: + actual = op(*new_args, **new_kwargs) + # NOTE: [What errors are Composite Compliance trying to catch?] + # + # There's two things we want to catch: + # - errors that would raise within the torch_dispatch impl + # - data_ptr accesses + # The first is easy to filter for (we could make the error a different + # error class), the second is always going to be a RuntimeError due to + # how it is implemented (if you try to access the data_ptr of the + # wrapper Tensor, it raises you some internal RuntimeError). + # + # So the most general thing to catch here was RuntimeError. If you + # are here and debugging why your test failed, it's plausible that + # the operator itself is broken and that there are other tests failing. + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +# Checks via the usage of torch dispatch mode certain anti-patterns that +# are not composite compliant. +# +# In particular, the anti-pattern we are trying to prevent is a user +# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps +# here because all factory functions will create tensors that are +# CompositeCompliantTensor. +# +# The general strategy is to wrap all Tensor args and kwargs in +# CompositeCompliantTensor wrappers. If an operator that is +# Composite does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_with_mode(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + + def wrap(e): + return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e + + expected = op(*args, **kwargs) + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + try: + with cct_mode: + actual = op(*args, **kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error(err) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +def gather_leaf_tensors(args, kwargs): + leaf_tensors = [] + args, _args_spec = tree_flatten(args) + kwargs, _kwargs_spec = tree_flatten(kwargs) + args = args + kwargs + for arg in args: + if not isinstance(arg, torch.Tensor): + continue + if arg.requires_grad: + leaf_tensors.append(arg) + return leaf_tensors + + +def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None): + if gradcheck_wrapper is None: + results = op(*args, **kwargs) + else: + results = gradcheck_wrapper(op, *args, **kwargs) + + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results] + leaf_tensors = gather_leaf_tensors(args, kwargs) + assert len(leaf_tensors) > 0 + return torch.autograd.grad(flat_diff_results, leaf_tensors, + grads, allow_unused=True, retain_graph=True) + + +# Checks if the backward formula is composite compliant by testing +# all possible permutations of {inputs, grad_outputs} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_backward_formula to things that aren't OpInfos +# while debugging. +def check_backward_formula(op: Callable, args, kwargs, + output_process_fn_grad=None, + gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode() + + expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper) + + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + leaf_tensors = gather_leaf_tensors(new_args, new_kwargs) + assert len(leaf_tensors) > 0 + + try: + if gradcheck_wrapper is None: + results = op(*new_args, **new_kwargs) + else: + results = gradcheck_wrapper(op, *new_args, **new_kwargs) + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + # NB: ones, not ones_like, so we get a regular Tensor here + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) + for r in flat_diff_results] + for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode): + try: + actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, + allow_unused=True, retain_graph=True) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_grads: {which_grad_is_batched}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True) + +# Checks if the forward AD formula is composite compliant by testing +# all possible permutations of {primals, tangents} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_forward_ad_formula to things that aren't OpInfos +# while debugging. +def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False) + + def maybe_tangent(t): + assert type(t) is not CCT + # Generate `tangent` tensor + # if given object is a Tensor and requires grad is set. + if isinstance(t, torch.Tensor) and t.requires_grad: + return torch.randn_like(t) + elif is_tensorlist(t): + return [torch.randn_like(e) if e.requires_grad else None for e in t] + return None + + tangent_args = tuple(maybe_tangent(arg) for arg in args) + flat_kwargs, spec = tree_flatten(kwargs) + flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs) + tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec) + + with fwAD.dual_level(): + def maybe_make_dual(dual): + # Returns dual tensor if primal is a tensor/tensor subclass + # with requires_grad set. + primal, tangent = dual + if isinstance(primal, torch.Tensor) and primal.requires_grad: + return fwAD.make_dual(primal.detach(), tangent) + elif is_tensorlist(primal): + return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri + for pri, tang in zip(primal, tangent, strict=True)) + return primal + + def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): + op_args = tuple(map(maybe_make_dual, zip(args, tangent_args, strict=True))) + op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()} + + if gradcheck_wrapper is None: + return op(*op_args, **op_kwargs) + return gradcheck_wrapper(op, *op_args, **op_kwargs) + + expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) + expected = tree_map(fwAD.unpack_dual, expected) + expected_primals = tree_map( + lambda x: x.primal, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + expected_tangents = tree_map( + lambda x: x.tangent, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + + # Permutations of arg and kwargs in CCT. + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + # Permutations tangent arg and tangent kwargs in CCT. + for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode): + new_tang_args, new_tang_kwargs, \ + which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice + + op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args, strict=True))) + op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()} + + try: + if gradcheck_wrapper is None: + actual = op(*op_args, **op_kwargs) + else: + actual = gradcheck_wrapper(op, *op_args, **op_kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n" + f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + actual = tree_map(fwAD.unpack_dual, actual) + actual_primals = tree_map( + lambda x: unwrap(x.primal), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + actual_tangents = tree_map( + lambda x: unwrap(x.tangent), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + assert_equal_fn(actual_primals, expected_primals, equal_nan=True) + assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py new file mode 100644 index 0000000000000000000000000000000000000000..32982d0a3e2a358a2530abd234b37a24c6efe77d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py @@ -0,0 +1,585 @@ +# mypy: allow-untyped-defs +import torch +import functools +from torch.testing import make_tensor +from torch.testing._internal.opinfo.core import ( + OpInfo, + SampleInput, +) +from torch.testing._internal.common_dtype import all_types_and +import numpy as np +from torch.testing._internal.autograd_function_db import ( + sample_inputs_numpy_cube, + sample_inputs_numpy_mul, + sample_inputs_numpy_mul_scalar, + sample_inputs_numpy_sort, + sample_inputs_numpy_take, +) +from torch import Tensor +from torch.types import Number +from typing import * # noqa: F403 + +# Note: [custom op db] +# +# This is a collection of custom operator test cases written as OpInfos +# so they can easily be consumed by OpInfo-based tests to check if subsystems +# support them correctly. + +def to_numpy(tensor): + return tensor.cpu().numpy() + +@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=()) +def numpy_cube(x: Tensor) -> tuple[Tensor, Tensor]: + x_np = to_numpy(x) + dx = torch.tensor(3 * x_np ** 2, device=x.device) + return torch.tensor(x_np ** 3, device=x.device), dx + +@numpy_cube.register_fake +def _(x): + return x.clone(), x.clone() + +def numpy_cube_setup_context(ctx, inputs, output): + x, = inputs + _cube, dx = output + ctx.save_for_backward(x, dx) + +def numpy_cube_backward(ctx, grad_out, grad_dx): + x, dx = ctx.saved_tensors + grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x) + return grad_x + +numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context) + +def numpy_cube_vmap(info, in_dims, x): + result = numpy_cube(x) + return result, (in_dims[0], in_dims[0]) + +numpy_cube.register_vmap(numpy_cube_vmap) + +@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=()) +def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + +@numpy_mul.register_fake +def _(x, y): + assert x.device == y.device + return (x * y).contiguous() + +def numpy_mul_setup_context(ctx, inputs, output): + ctx.save_for_backward(*inputs) + +def numpy_mul_backward(ctx, grad_out): + x, y = ctx.saved_tensors + grad_x = grad_out * y if ctx.needs_input_grad[0] else None + grad_y = grad_out * x if ctx.needs_input_grad[1] else None + return grad_x, grad_y + +numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context) + +def numpy_mul_vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul.register_vmap(numpy_mul_vmap) + +@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=()) +def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor: + return torch.tensor(to_numpy(x) * scalar, device=x.device) + +@numpy_mul_scalar.register_fake +def _(x, *, scalar): + return (x * scalar).contiguous() + +def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output): + ctx.scalar = keyword_only_inputs["scalar"] + +def numpy_mul_scalar_backward(ctx, grad_out): + grad_x = grad_out * ctx.scalar + return grad_x + +numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context) + +def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar): + x_bdim, = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + result = x * scalar + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap) + +@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=()) +def numpy_sort(x: Tensor, dim: int) -> tuple[Tensor, Tensor, Tensor]: + device = x.device + x = to_numpy(x) + ind = np.argsort(x, axis=dim) + ind_inv = np.argsort(ind, axis=dim) + result = np.take_along_axis(x, ind, axis=dim) + return ( + torch.tensor(result, device=device), + torch.tensor(ind, device=device), + torch.tensor(ind_inv, device=device), + ) + +@numpy_sort.register_fake +def _(x, dim): + return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long) + +def numpy_sort_setup_context(ctx, inputs, output): + _out, ind, ind_inv = output + ctx.dim = inputs[1] + ctx.save_for_backward(ind, ind_inv) + ctx.mark_non_differentiable(ind, ind_inv) + +def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv): + ind, ind_inv = ctx.saved_tensors + return numpy_take(grad_out, ind_inv, ind, ctx.dim), None + +numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context) + +def numpy_sort_vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + dim = dim if dim >= 0 else dim + x.dim() - 1 + result = numpy_sort(x, dim + 1) + return result, (0, 0, 0) + +numpy_sort.register_vmap(numpy_sort_vmap) + +@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=()) +def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor: + device = x.device + x = to_numpy(x) + ind = to_numpy(ind) + return torch.tensor(np.take_along_axis(x, ind, dim), device=device) + +@numpy_take.register_fake +def _(x, ind, ind_inv, dim): + assert x.device == ind.device + assert x.device == ind_inv.device + assert ind.dtype == torch.long + assert ind_inv.dtype == torch.long + return torch.empty_like(x) + +def numpy_take_setup_context(ctx, inputs, output): + _x, ind, ind_inv, dim = inputs + ctx.dim = dim + ctx.save_for_backward(ind, ind_inv) + +def numpy_take_backward(ctx, grad_out): + ind, ind_inv = ctx.saved_tensors + grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim) + return grad_x, None, None, None + +numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context) + +def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return numpy_take(x, ind, ind_inv, dim + 1), 0 + +numpy_take.register_vmap(numpy_take_vmap) + +@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=()) +def numpy_nonzero(x: Tensor) -> Tensor: + x_np = to_numpy(x) + res = np.stack(np.nonzero(x_np), axis=1) + if res.shape[0] <= 1: + raise RuntimeError("not supported") + return torch.tensor(res, device=x.device) + +@numpy_nonzero.register_fake +def _(x): + ctx = torch._custom_op.impl.get_ctx() + i0 = ctx.create_unbacked_symint() + shape = [i0, x.dim()] + result = x.new_empty(shape, dtype=torch.long) + return result + +def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = 10 + result = make_arg(shape, low=0.9, high=2) + mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long) + with torch.no_grad(): + result *= mask + + yield SampleInput(result, args=()) + +def numpy_nonzero_vmap(info, in_dims, x): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nonzero.register_vmap(numpy_nonzero_vmap) + +@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=()) +def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor: + return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device) + +@numpy_view_copy.register_fake +def _(x, shape) -> Tensor: + return x.clone().view(shape).clone() + +def numpy_view_copy_setup_context(ctx, inputs, output) -> None: + ctx.x_shape = inputs[0].shape + +def numpy_view_copy_backward(ctx, grad_out): + return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None + +numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context) + +def numpy_view_copy_vmap(info, in_dims, x, shape): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + x_shape = x.shape[0] + batch_shape = (x_shape, *shape) + result = numpy_view_copy(x, batch_shape) + return result, 0 + +numpy_view_copy.register_vmap(numpy_view_copy_vmap) + +def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + result = make_arg(2, 3, 4, low=0.9, high=2) + yield SampleInput(result, args=([2, 12],)) + +@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=()) +def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor: + assert len(xs) > 0 + assert all(x.device == xs[0].device for x in xs) + assert all(x.dtype == xs[0].dtype for x in xs) + np_xs = [to_numpy(x) for x in xs] + np_out = np.concatenate(np_xs, axis=dim) + return torch.tensor(np_out, device=xs[0].device) + +@numpy_cat.register_fake +def _(xs, dim): + assert len(xs) > 0 + assert all(x.device == xs[0].device for x in xs) + assert all(x.dtype == xs[0].dtype for x in xs) + return torch.cat(xs, dim=dim) + +def numpy_cat_setup_context(ctx, inputs, output): + xs, dim = inputs + ctx.dim_sizes = [x.shape[dim] for x in xs] + ctx.dim = dim + +def numpy_cat_backward(ctx, grad_out): + dim_sizes = ctx.dim_sizes + dim = ctx.dim + + splits = list(np.cumsum(dim_sizes)[:-1]) + grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim) + return grad_xs, None + +numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context) + +def numpy_cat_vmap(info, in_dims, x, dim): + x_bdim, = in_dims + result = numpy_cat(x, dim) + return result, x_bdim + +numpy_cat.register_vmap(numpy_cat_vmap) + +def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + r0 = make_arg(2, 3, 4, low=0.9, high=2) + r1 = make_arg(4, 3, 4, low=0.9, high=2) + r2 = make_arg(5, 3, 4, low=0.9, high=2) + yield SampleInput([r0, r1, r2], args=(0,)) + +@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=()) +def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]: + x_np = to_numpy(x) + arrs = np.split(x_np, splits, axis=dim) + return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs] + +@numpy_split_copy.register_fake +def _(x, splits, dim): + return [xi.clone() for xi in torch.tensor_split(x, splits, dim)] + +def numpy_split_copy_setup_context(ctx, inputs, output): + _, _, dim = inputs + ctx.dim = dim + +def numpy_split_copy_backward(ctx, grad_out): + result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim) + return result, None, None + +numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context) + +def numpy_split_copy_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result = numpy_split_copy(x, splits, dim + 1) + return result, 0 + +numpy_split_copy.register_vmap(numpy_split_copy_vmap) + +def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + x = make_arg(2, 9, low=0.9, high=2) + yield SampleInput(x, args=([1, 3, 6], 1)) + +@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=()) +def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> tuple[List[Tensor], int]: + x_np = to_numpy(x) + arrs = np.split(x_np, splits, axis=dim) + return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits) + +@numpy_split_copy_with_int.register_fake +def _(x, splits, dim): + return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits) + +def numpy_split_copy_with_int_setup_context(ctx, inputs, output): + _, _, dim = inputs + ctx.dim = dim + +def numpy_split_copy_with_int_backward(ctx, grad_out, _): + return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None + +numpy_split_copy_with_int.register_autograd( + numpy_split_copy_with_int_backward, + setup_context=numpy_split_copy_with_int_setup_context) + +def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result, len_split = numpy_split_copy_with_int(x, splits, dim + 1) + return (result, len_split), ([0 for _ in range(len(result))], None) + +numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap) + +@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=()) +def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor: + # Adapted from Ross Girshick's fast-rcnn implementation at + # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py + assert boxes.device == scores.device + device = boxes.device + + boxes = to_numpy(boxes) + scores = to_numpy(scores) + + N = boxes.shape[0] + assert boxes.shape == (N, 4) + assert scores.shape == (N,) + + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= iou_threshold)[0] + order = order[inds + 1] + + result = torch.tensor(np.stack(keep), device=device) + # Needed for data-dependent condition :( + assert result.size(0) >= 2 + return result + +@numpy_nms.register_fake +def _(boxes, scores, iou_threshold): + assert boxes.device == scores.device + N = boxes.shape[0] + assert boxes.shape == (N, 4) + assert scores.shape == (N,) + + ctx = torch._custom_op.impl.get_ctx() + i0 = ctx.create_unbacked_symint() + result = boxes.new_empty([i0], dtype=torch.int64) + return result + +def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nms.register_vmap(numpy_nms_vmap) + +def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial(make_tensor, device=device, dtype=dtype) + N = 64 + xs = make_arg([N], low=0, high=28) + dx = make_arg([N], low=0, high=4) + ys = make_arg([N], low=0, high=28) + dy = make_arg([N], low=0, high=4) + boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad) + scores = make_arg([N], low=0, high=1, requires_grad=requires_grad) + iou_threshold = make_arg([], low=0, high=1).item() + + yield SampleInput(boxes, args=(scores, iou_threshold)) + +custom_op_db = [ + OpInfo( + 'NumpyCubeCustomOp', + op=numpy_cube._opoverload, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulCustomOp', + op=numpy_mul._opoverload, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulScalarCustomOp', + op=numpy_mul_scalar._opoverload, + sample_inputs_func=sample_inputs_numpy_mul_scalar, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpySortCustomOp', + op=numpy_sort._opoverload, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyTakeCustomOp', + op=numpy_take._opoverload, + sample_inputs_func=sample_inputs_numpy_take, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyNonzeroCustomOp', + op=numpy_nonzero._opoverload, + sample_inputs_func=sample_inputs_numpy_nonzero, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=False, + supports_out=False, + ), + OpInfo( + 'NumpyNMSCustomOp', + op=torch.ops._torch_testing.numpy_nms, + sample_inputs_func=sample_inputs_numpy_nms, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=False, + supports_out=False, + ), + OpInfo( + 'NumpyViewCopyCustomOp', + op=torch.ops._torch_testing.numpy_view_copy, + sample_inputs_func=sample_inputs_numpy_view_copy, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + supports_out=False, + ), + OpInfo( + 'NumpyCatCustomOp', + op=torch.ops._torch_testing.numpy_cat, + sample_inputs_func=sample_inputs_numpy_cat, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), + OpInfo( + 'NumpySplitCopyCustomOp', + op=torch.ops._torch_testing.numpy_split_copy, + sample_inputs_func=sample_inputs_numpy_split_copy, + dtypes=all_types_and(torch.bool, torch.half), + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), + OpInfo( + 'NumpySplitCopyWithIntCustomOp', + op=torch.ops._torch_testing.numpy_split_copy_with_int, + sample_inputs_func=sample_inputs_numpy_split_copy, + dtypes=all_types_and(torch.bool, torch.half), + gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0], + supports_autograd=True, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + ), +] + + +# ============================================================== +# some mechanical test cases +# ============================================================== + +lib = torch.library.Library("_torch_testing", "FRAGMENT") # noqa: TOR901 + +lib.define("source0(Tensor x) -> Tensor") + +@torch.library.register_fake("_torch_testing::source0", lib=lib) +def _(x): + return x.clone() + +lib.define("source1(Tensor x) -> Tensor") + +def source1_fake(x): + return x.clone() + +torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib) + +lib.define("source2(Tensor x) -> Tensor") + +@torch.library.register_fake("_torch_testing::source2", lib=lib) +def _(x): + return x.clone() + +lib.define("source3(Tensor x) -> Tensor") + +def source3_fake(x): + return x.clone() + +torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib) + + +@torch.library.custom_op("_torch_testing::source4", mutates_args=()) +def source4(x: Tensor) -> Tensor: + return x.clone() + +@source4.register_fake +def _(x): + return x.clone() + +@torch.library.custom_op("_torch_testing::source5", mutates_args=()) +def source5(x: Tensor) -> Tensor: + return x.clone() + +def source5_fake(x): + return x.clone() + +source5.register_fake(source5_fake) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..de1b44ba8dac890142eaf2b013c2399eb59c2193 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py @@ -0,0 +1,160 @@ +# mypy: ignore-errors + + +from collections import namedtuple + +import torch +import torch.utils._pytree as pytree +from torch.utils._python_dispatch import return_and_correct_aliasing + + +FancyNamedTuple = namedtuple("FancyNamedTuple", ["foo", "bar"]) + + +# A simple tensor subclass that holds a tensor with custom metadata and custom method +class ConstantExtraMetadataTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + shape = elem.shape + kwargs = {} + kwargs["strides"] = elem.stride() + kwargs["storage_offset"] = elem.storage_offset() + kwargs["device"] = elem.device + kwargs["layout"] = elem.layout + kwargs["requires_grad"] = elem.requires_grad + kwargs["dtype"] = elem.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem): + self.elem = elem + self.constant_attribute = 4 + + def __repr__(self): + inner_repr = repr(self.elem) + return f"CustomTensor({inner_repr})" + + def get_complicated_metadata(self): + return FancyNamedTuple(self.constant_attribute, self.constant_attribute) + + def __tensor_flatten__(self): + return ["elem"], self.constant_attribute + + def add_constant(self, a): + self.constant_attribute += a + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is not None + elem = inner_tensors["elem"] + out = ConstantExtraMetadataTensor(elem) + out.constant_attribute = meta + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, args + ) + + kwargs_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, kwargs + ) + + out_inner = func(*args_inner, **kwargs_inner) + out_inner_flat, spec = pytree.tree_flatten(out_inner) + # for aten ops that return non-tensors, just assume that + # our cust inner tensors return the same value + out_flat = [ + ConstantExtraMetadataTensor(o_inner) + if isinstance(o_inner, torch.Tensor) + else o_inner + for o_inner in out_inner_flat + ] + out = pytree.tree_unflatten(out_flat, spec) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# A simple tensor subclass that always returns plain tensor during __torch_dispatch__ +# It is similar to TwoTensor and is used to simulate torchao quantized tensors +class CustomTensorPlainOut(torch.Tensor): + @staticmethod + def __new__(cls, elem1, elem2): + shape = elem1.shape + kwargs = {} + kwargs["strides"] = elem1.stride() + kwargs["storage_offset"] = elem1.storage_offset() + kwargs["device"] = elem1.device + kwargs["layout"] = elem1.layout + kwargs["requires_grad"] = elem1.requires_grad + kwargs["dtype"] = elem1.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem1, elem2): + self.elem1 = elem1 + self.elem2 = elem2 + + def get_elem(self): + return self.elem1 + + def __repr__(self): + inner_repr_1 = repr(self.elem1) + inner_repr_2 = repr(self.elem2) + return f"CustomTensorPlainOut({inner_repr_1}, {inner_repr_2})" + + def __tensor_flatten__(self): + return ["elem1", "elem2"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + elem1 = inner_tensors["elem1"] + elem2 = inner_tensors["elem2"] + out = CustomTensorPlainOut(elem1, elem2) + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # Don't use this tensor with view ops + if kwargs is None: + kwargs = {} + args_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, args + ) + + kwargs_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, kwargs + ) + + args_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, args + ) + + kwargs_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, kwargs + ) + + out_inner_1 = func(*args_inner_1, **kwargs_inner_1) + out_inner_2 = func(*args_inner_2, **kwargs_inner_2) + + out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1) + out_inner_flat_2, spec = pytree.tree_flatten(out_inner_2) + + if func.is_view: + new_out = pytree.tree_unflatten( + ( + CustomTensorPlainOut(tensor1, tensor2) + for tensor1, tensor2 in zip( + out_inner_flat_1, out_inner_flat_2, strict=True + ) + ), + spec, + ) + return return_and_correct_aliasing(func, args, kwargs, new_out) + + out_new = ( + out_inner_flat_1[ix] + out_inner_flat_2[ix] + for ix in range(len(out_inner_flat_1)) + ) + + return pytree.tree_unflatten(out_new, spec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py new file mode 100644 index 0000000000000000000000000000000000000000..8755643a78cca80668988df9e9db3de75778b5db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py new file mode 100644 index 0000000000000000000000000000000000000000..19b0b8ee53d3b530aa33978c7a13da4e5fee4ebd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) + self.relu = nn.ReLU() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45af2552cf25cef03a517f5b136c1a2e61c3a61d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py @@ -0,0 +1,199 @@ +# mypy: ignore-errors + +import re +import sys +import time +from functools import partial, wraps + +import torch.distributed as dist +import torch.distributed.rpc as rpc +from torch.distributed.rpc import _rref_context_get_debug_info +from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN + + +if not dist.is_available(): + print("c10d not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}" + +def dist_init( + old_test_method=None, + setup_rpc: bool = True, + clean_shutdown: bool = True, + faulty_messages=None, + messages_to_delay=None, +): + """ + We use this decorator for setting up and tearing down state since + MultiProcessTestCase runs each `test*` method in a separate process and + each process just runs the `test*` method without actually calling + 'setUp' and 'tearDown' methods of unittest. + + Note: pass the string representation of MessageTypes that should be used + with the faulty agent's send function. By default, all retriable messages + ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE", + "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is + set from faulty_rpc_agent_test_fixture.py). + """ + # If we use dist_init without arguments (ex: @dist_init), old_test_method is + # appropriately set and we return the wrapper appropriately. On the other + # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)), + # old_test_method is None and we return a functools.partial which is the real + # decorator that is used and as a result we recursively call dist_init with + # old_test_method and the rest of the arguments appropriately set. + if old_test_method is None: + return partial( + dist_init, + setup_rpc=setup_rpc, + clean_shutdown=clean_shutdown, + faulty_messages=faulty_messages, + messages_to_delay=messages_to_delay, + ) + + @wraps(old_test_method) + def new_test_method(self, *arg, **kwargs): + # Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted + # in tests. + import torch.distributed.rpc.api as api + + api._ignore_rref_leak = False + self.worker_id = self.rank + self.setup_fault_injection(faulty_messages, messages_to_delay) + + rpc_backend_options = self.rpc_backend_options + if setup_rpc: + if TEST_WITH_TSAN: + # TSAN runs much slower. + rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5 + rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60 + + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + return_value = old_test_method(self, *arg, **kwargs) + + if setup_rpc: + rpc.shutdown(graceful=clean_shutdown) + + return return_value + + return new_test_method + + +def noop() -> None: + pass + + +def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str: + """ + Loops until an RPC to the given rank fails. This is used to + indicate that the node has failed in unit tests. + Args: + rank (int): Rank of the node expected to fail + expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure + occurs, not just any. + """ + while True: + try: + rpc.rpc_sync(f"worker{rank}", noop, args=()) + time.sleep(0.1) + except Exception as e: + if re.search(pattern=expected_error_regex, string=str(e)): + return str(e) + + +def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None: + """ + The RRef protocol holds forkIds of rrefs in a map until those forks are + confirmed by the owner. The message confirming the fork may arrive after + our tests check whether this map is empty, which leads to failures and + flaky tests. to_here also does not guarantee that we have finished + processind the owner's confirmation message for the RRef. This function + loops until the map is empty, which means the messages have been received + as processed. Call this function before asserting the map returned by + _get_debug_info is empty. + """ + start = time.time() + while True: + debug_info = _rref_context_get_debug_info() + num_pending_futures = int(debug_info["num_pending_futures"]) + num_pending_users = int(debug_info["num_pending_users"]) + if num_pending_futures == 0 and num_pending_users == 0: + break + time.sleep(0.1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting to flush pending futures and users, " + f"had {num_pending_futures} pending futures and {num_pending_users} pending users" + ) + + +def get_num_owners_and_forks() -> tuple[str, str]: + """ + Retrieves number of OwnerRRefs and forks on this node from + _rref_context_get_debug_info. + """ + rref_dbg_info = _rref_context_get_debug_info() + num_owners = rref_dbg_info["num_owner_rrefs"] + num_forks = rref_dbg_info["num_forks"] + return num_owners, num_forks + + +def wait_until_owners_and_forks_on_rank( + num_owners: int, num_forks: int, rank: int, timeout: int = 20 +) -> None: + """ + Waits until timeout for num_forks and num_owners to exist on the rank. Used + to ensure proper deletion of RRefs in tests. + """ + start = time.time() + while True: + num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync( + worker_name(rank), get_num_owners_and_forks, args=(), timeout=5 + ) + num_owners_on_rank = int(num_owners_on_rank) + num_forks_on_rank = int(num_forks_on_rank) + if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks: + return + time.sleep(1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank," + f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks" + ) + + +def initialize_pg(init_method, rank: int, world_size: int) -> None: + # This is for tests using `dist.barrier`. + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + + +def worker_name(rank: int) -> str: + return f"worker{rank}" + + +def get_function_event(function_events, partial_event_name): + """ + Returns the first event that matches partial_event_name in the provided + function_events. These function_events should be the output of + torch.autograd.profiler.function_events(). + + Args: + function_events: function_events returned by the profiler. + event_name (str): partial key that the event was profiled with. + """ + event = [event for event in function_events if partial_event_name in event.name][0] # noqa: RUF015 + return event diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..10002da5854421a2d53076eb8458f42ac7a1e4e2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs + +from contextlib import contextmanager +from datetime import timedelta +from functools import partial, wraps + +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d + + +class MockProcessGroup(dist.ProcessGroup): + def getBackendName(self): + return "mock_process_group" + + +def create_mock_pg(prefix_store, rank, world_size, timeout): + return MockProcessGroup(rank, world_size) + + +dist.Backend.register_backend("mock_process_group", create_mock_pg) + + +def mock_init_dist(rank, world_size): + # !!! WARNING !!! + # Kids don't try this at home, this is a cute pile of hacks that + # depends on a small mountain of c10d internals + assert not dist.is_initialized() + store = dist.HashStore() + # Trick _store_based_barrier into believing everyone else already checked-in + # Zero is the group index + store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1) + dist.init_process_group( + backend="mock_process_group", + rank=rank, + world_size=world_size, + store=store, + group_name="fake", + timeout=timedelta(seconds=1), + ) + + +@contextmanager +def with_dist(rank=0, world_size=2): + """ + Context manager that initializer c10d with a fake process group. + """ + mock_init_dist(rank=rank, world_size=world_size) + try: + yield + finally: + dist.destroy_process_group() + + +def with_fake_comms(func=None, rank=0, world_size=2): + """ + Function wrapper that inits a fake process group designed for testing. + Right now only querying for world size is available + """ + if func is None: + return partial(with_fake_comms, rank=rank, world_size=world_size) + + @wraps(func) + def wrapper(self, *args, **kwargs): + with with_dist(rank, world_size): + func(self, *args, **kwargs) + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc69b7920cf06d24dceac0bb2743004c0b6c64e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py @@ -0,0 +1,145 @@ +""" +This file contains the list of tests that are known to fail under Dynamo + +We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures` +We generate skipIfTorchDynamo* for all tests in `dynamo_skips` +We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips` + +For an easier-than-manual way of generating and updating these lists, +see scripts/compile_tests/update_failures.py + +If you're adding a new test, and it's failing PYTORCH_TEST_WITH_DYNAMO=1, +either add the appropriate decorators to your test or add skips for them +via test/dynamo_skips and test/dynamo_expected_failures. + +*These are not exactly unittest.expectedFailure and unittest.skip. We'll +always execute the test and then suppress the signal, if necessary. +If your tests crashes, or is slow, please use @skipIfTorchDynamo instead. + +The expected failure and skip files are located in test/dynamo_skips and +test/dynamo_expected_failures. They're individual files rather than a list so +git will merge changes easier. +""" + +import logging +import os +import sys +from typing import Optional + + +def find_test_dir() -> Optional[str]: + # Find the path to the dynamo expected failure and skip files. + from os.path import abspath, basename, dirname, exists, join, normpath + + if sys.platform == "win32": + return None + + # Check relative to this file (local build): + test_dir = normpath(join(dirname(abspath(__file__)), "../../../test")) + if exists(join(test_dir, "dynamo_expected_failures")): + return test_dir + + # Check relative to __main__ (installed builds relative to test file): + main = sys.modules["__main__"] + file = getattr(main, "__file__", None) + if file is None: + # Generated files do not have a module.__file__ + return None + test_dir = dirname(abspath(file)) + while dirname(test_dir) != test_dir: + if basename(test_dir) == "test" and exists( + join(test_dir, "dynamo_expected_failures") + ): + return test_dir + test_dir = dirname(test_dir) + + # Not found + return None + + +test_dir = find_test_dir() +if not test_dir: + logger = logging.getLogger(__name__) + logger.warning( + "test/dynamo_expected_failures directory not found - known dynamo errors won't be skipped." + ) + +# Tests that run without strict mode in PYTORCH_TEST_WITH_INDUCTOR=1. +# Please don't add anything to this list. +FIXME_inductor_non_strict = { + "test_modules", + "test_ops", + "test_ops_gradients", + "test_torch", +} + +# We generate unittest.expectedFailure for all of the following tests +# when run under PYTORCH_TEST_WITH_DYNAMO=1. +# see NOTE [dynamo_test_failures.py] for more details +# +# This lists exists so we can more easily add large numbers of failing tests, +if test_dir is None: + dynamo_expected_failures = set() + dynamo_skips = set() + + inductor_expected_failures = set() + inductor_skips = set() + + compiled_autograd_skips = set() +else: + dynamo_failures_directory = os.path.join(test_dir, "dynamo_expected_failures") + dynamo_skips_directory = os.path.join(test_dir, "dynamo_skips") + + dynamo_expected_failures = set(os.listdir(dynamo_failures_directory)) + dynamo_skips = set(os.listdir(dynamo_skips_directory)) + + inductor_failures_directory = os.path.join(test_dir, "inductor_expected_failures") + inductor_skips_directory = os.path.join(test_dir, "inductor_skips") + + inductor_expected_failures = set(os.listdir(inductor_failures_directory)) + inductor_skips = set(os.listdir(inductor_skips_directory)) + + compiled_autograd_skips_directory = os.path.join( + test_dir, "compiled_autograd_skips" + ) + compiled_autograd_skips = set(os.listdir(compiled_autograd_skips_directory)) + +# TODO: due to case sensitivity problems, for now list these files by hand +extra_dynamo_skips = { + "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_t_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_T_cpu_float32", + "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_t_cpu_float32", +} +dynamo_skips = dynamo_skips.union(extra_dynamo_skips) + + +# verify some invariants +for test in ( + dynamo_expected_failures + | dynamo_skips + | inductor_expected_failures + | inductor_skips +): + if len(test.split(".")) != 2: + raise AssertionError(f'Invalid test name: "{test}"') + +dynamo_intersection = dynamo_expected_failures.intersection(dynamo_skips) +if len(dynamo_intersection) > 0: + raise AssertionError( + "there should be no overlap between dynamo_expected_failures " + "and dynamo_skips, got " + str(dynamo_intersection) + ) + +inductor_intersection = inductor_expected_failures.intersection(inductor_skips) +if len(inductor_intersection) > 0: + raise AssertionError( + "there should be no overlap between inductor_expected_failures " + "and inductor_skips, got " + str(inductor_intersection) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1e93c41de72a765415a8dce5d3c98c8cd0cf2c41 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py @@ -0,0 +1,44 @@ +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +e_bool = True +e_int = 1 +e_float = 1.0 +e_string = "string" +e_list = [1] +e_set = {1} +e_tuple = (1,) +e_dict = {1: 2} +e_none: Optional[bool] = None +e_optional: Optional[bool] = True +e_ignored = True +_e_ignored = True +magic_cache_config_ignored = True +# [@compile_ignored: debug] +e_compile_ignored = True +e_config: bool = Config(default=True) +e_jk: bool = Config(justknob="does_not_exist", default=True) +e_jk_false: bool = Config(justknob="does_not_exist", default=False) +e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False) +e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True) +e_env_default_str: bool = Config(env_name_default="ENV_STR", default="default") +e_env_default_str_empty: bool = Config( + env_name_default="ENV_STR_EMPTY", default="default" +) +e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False) +e_aliased_bool: bool = Config( + alias="torch.testing._internal.fake_config_module2.e_aliasing_bool" +) + + +class nested: + e_bool = True + + +_cache_config_ignore_prefix = ["magic_cache_config"] +_save_config_ignore = ["e_ignored"] + +install_config_module(sys.modules[__name__]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py new file mode 100644 index 0000000000000000000000000000000000000000..77c2e2baa4ddca7685adf734809488979c21ab63 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py @@ -0,0 +1,13 @@ +import sys + +from torch.utils._config_module import Config, install_config_module + + +e_aliasing_bool = False + +e_env_default_multi: bool = Config( + env_name_default=["ENV_TRUE", "ENV_FALSE"], default=False +) +e_env_force_multi: bool = Config(env_name_force=["ENV_FAKE", "ENV_TRUE"], default=False) + +install_config_module(sys.modules[__name__]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8fdd3bb138fa7225a10d7518f148784a7bb116 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py @@ -0,0 +1,2905 @@ +""" +This file is needed for generating procedural tests required for +testing __torch_function__. See tests/test_overrides.py. +""" + +# flake8: noqa +import torch + +annotated_args = { + torch._C._VariableFunctions._cast_Byte: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Char: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Double: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Float: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Int: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Long: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Short: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Half: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dual: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._unpack_dual: [{'is_kwarg_only': 'False', 'name': 'dual', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.align_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._functional_assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_tensor_metadata: [{'is_kwarg_only': 'False', 'name': 'a', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._print: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._functional_sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dep_token: [], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._use_cudnn_rnn_flatten_weight: [], + torch._C._VariableFunctions._cudnn_rnn_flatten_weight: [{'is_kwarg_only': 'False', 'name': 'weight_arr', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight_buf', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions._cudnn_init_dropout_state: [{'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout_seed', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._debug_has_internal_overlap: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_dropout: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions._masked_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool?'}], + torch._C._VariableFunctions._sobol_engine_draw: [{'is_kwarg_only': 'False', 'name': 'quasi', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sobol_engine_ff_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_scramble_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ltm', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_initialize_state_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._reshape_from_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._shape_as_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.imag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_check_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_functorch_fallback: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._dim_arange: [{'is_kwarg_only': 'False', 'name': 'like', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantized_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._batch_norm_impl_index: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions.bilinear: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binary_cross_entropy_with_logits: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.broadcast_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.block_diag: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_is_acceptable: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.constant_pad_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution_mode: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_tbc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from_and_resize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'H', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'W', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._mps_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_grid_sampler: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummax_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummin_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.einsum: [{'is_kwarg_only': 'False', 'name': 'equation', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max_norm', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_type', 'simple_type': 'double'}], + torch._C._VariableFunctions._embedding_bag_forward_only: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._rowwise_prune: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'compressed_indices_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_grad_by_freq', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sparse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'per_sample_weights', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'include_last_offset', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_idx', 'simple_type': 'int64_t?'}], + torch._C._VariableFunctions._embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty_permuted: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'physical_layout', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._empty_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._empty_per_channel_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._resize_output_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'device', 'simple_type': 'Device'}], + torch._C._VariableFunctions.empty_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'qtensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_strided: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.from_file: [{'is_kwarg_only': 'False', 'name': 'filename', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.grid_sampler: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._grid_sampler_2d_cpu_fallback: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.hinge_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_groups', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.native_group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'HxW', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'group', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._validate_compressed_sparse_indices: [{'is_kwarg_only': 'False', 'name': 'is_crow', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'compressed_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cdim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'nnz', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_get_plan_cache_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_get_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_set_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}, {'is_kwarg_only': 'False', 'name': 'max_size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_clear_plan_cache: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._unsafe_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}], + torch._C._VariableFunctions._unsafe_masked_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'fill', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._unsafe_masked_index_put_accumulate: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unsafe_index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._index_put_impl_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.instance_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'use_input_stats', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kl_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.native_layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._fused_rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double?'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_linear_backward_weights: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias_defined', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cslt_compress: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm_search: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_tile: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply_dense: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'meta', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_mm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_addmm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mixed_dtypes_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_quantize_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_gemm_matrix_fp16: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_linear_prepack: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_quantized_linear_prepacked: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_channel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'K', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.margin_ranking_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max_pool1d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.quantized_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.quantized_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.quantized_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mps_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_rnn_layer: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight0', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reverse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.miopen_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_depthwise_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_int4pack_mm_with_scales_and_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dyn_quant_pack_4bit_weight: [{'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales_zeros', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._dyn_quant_matmul_4bit: [{'is_kwarg_only': 'False', 'name': 'inp', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int8pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sparse_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit_no_training: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_gather_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.batch_norm_gather_stats_with_counts: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'counts', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_backward_reduce: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'input_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'weight_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bias_g', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm_backward_elemt: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'sum_dy', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sum_dy_xmu', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_update_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}], + torch._C._VariableFunctions.is_vulkan_available: [], + torch._C._VariableFunctions._nnpack_available: [], + torch._C._VariableFunctions._nnpack_spatial_convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pairwise_distance: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cdist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._euclidean_dist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pdist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_similarity: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pixel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'upscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.pixel_unshuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'downscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson_nll_loss: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'log_input', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'full', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'reduction', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scalar_tensor: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._mkldnn_reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._prelu_kernel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.selu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.selu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._mkldnn_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._mkldnn_transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._VariableFunctions.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._transform_bias_rescale_qkv: [{'is_kwarg_only': 'False', 'name': 'qkv', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._nested_tensor_from_mask: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_from_mask_left_aligned: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cpu_nested_shape_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_and_nested_example: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nt_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_lengths: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_ragged_idx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_min_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_max_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_jagged_dummy: [{'is_kwarg_only': 'False', 'name': 'any', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_compute_contiguous_strides_offsets: [{'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._trilinear: [{'is_kwarg_only': 'False', 'name': 'i1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'expand1', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand2', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand3', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sumdim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.triplet_margin_loss: [{'is_kwarg_only': 'False', 'name': 'anchor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'positive', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'negative', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._has_compatible_shallow_copy_type: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unique_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unique_consecutive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm_except_dim: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm_interface: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._efficientzerotensor: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dirichlet_grad: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'total', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sample_dirichlet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binomial: [{'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'prob', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_csr_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_csr_prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._scaled_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._scaled_grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_grouped_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._validate_sparse_coo_tensor_args: [{'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_compressed_tensor_args: [{'is_kwarg_only': 'False', 'name': 'compressed_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'layout', 'simple_type': 'Layout'}], + torch._C._VariableFunctions._validate_sparse_csr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_csc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._to_cpu: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._to_sparse_semi_structured: [{'is_kwarg_only': 'False', 'name': 'dense', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantize_per_tensor_dynamic: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'reduce_range', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_channel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_per_tensor_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_per_channel_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_enabled', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fused_moving_avg_obs_fake_quant: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fused_moving_avg_obs_fq_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._choose_qparams_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._saturate_weight_to_fp16: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.choose_qparams_optimized: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'numel', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'n_bins', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ratio', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'bit_width', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'indexing', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.cartesian_prod: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.combinations: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar1', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scalar2', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.can_cast: [{'is_kwarg_only': 'False', 'name': 'from_', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.promote_types: [{'is_kwarg_only': 'False', 'name': 'type1', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'type2', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._lstm_mps: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantized_lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._pack_padded_sequence: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions._pad_packed_sequence: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'total_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._masked_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.triu_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_check_errors: [{'is_kwarg_only': 'False', 'name': 'info', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'api_name', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'True', 'name': 'is_matrix', 'simple_type': 'bool'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._lu_with_info: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._histogramdd_bin_edges: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_cts: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_tensors: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._amp_foreach_non_finite_check_and_unscale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'inv_scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._amp_update_scale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'growth_tracker', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_growth_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'scale_backoff_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'growth_interval', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._remove_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_serialization_subcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_parallel_materialize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_parallel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'bool'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.segment_reduce: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._nested_tensor_from_tensor_list: [{'is_kwarg_only': 'False', 'name': 'list', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_softmax_with_shape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._safe_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._transformer_encoder_layer_fwd: [{'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'use_gelu', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'norm_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._native_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_sdp_choice: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math_for_mps: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention_for_cpu: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_efficient_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._scaled_dot_product_cudnn_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._triton_scaled_dot_attention: [{'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fill_mem_eff_dropout_mask_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dropout_p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'seed', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'offset', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._triton_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foobar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._propagate_xla_data: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_linear: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.relu6: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.relu6_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.one_hot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv2d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv3d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.cross_entropy_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn._pad_circular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn._pad_enum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}], + torch._C._nn.pad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.conv_depthwise3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_dilated2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_dilated3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn._test_optional_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?'}], + torch._C._nn._test_optional_filled_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?', 'size': 2}], + torch._C._nn._test_optional_floatlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'ArrayRef?'}], + torch._C._nn._test_string_default: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_warn_in_autograd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.pad_sequence: [{'is_kwarg_only': 'False', 'name': 'sequences', 'simple_type': 'TensorList'}], + torch._C._nn.flatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.unflatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'flat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.scaled_dot_product_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_diagonal: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg._linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.retain_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rename_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.rename: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'order', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'ellipsis_idx', 'simple_type': 'int64_t'}], + torch.Tensor.align_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.refine_names: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.chalf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.addr_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fill_diagonal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.new_empty: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_empty_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_full: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.new_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_ones: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expand: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.expand_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.mvlgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_pinned: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.repeat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch.Tensor.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.reshape_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum_to_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.t_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch.Tensor.transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch.Tensor.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_strides: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_storage_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.type_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.unsqueeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.view_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.heaviside_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.addmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_resize_and_clear_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_mask: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor._sparse_mask_projection: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimI: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dense_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimV: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nnz: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_coalesced: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._coalesced_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coalesced', 'simple_type': 'bool'}], + torch.Tensor.indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.crow_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.col_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ccol_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.row_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_mkldnn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qscheme: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._autocast_to_reduced_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cuda_dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'cpu_dtype', 'simple_type': 'ScalarType'}], + torch.Tensor._autocast_to_full_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}], + torch.Tensor.is_set_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.tril_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.triu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t?'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.uniform_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cauchy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exponential_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geometric_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch.Tensor.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch.Tensor.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapaxes_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch.Tensor.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch.Tensor.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch.Tensor.lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfinv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.unfold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch.Tensor.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.record_stream: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Stream'}], + torch.Tensor.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.to_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'double'}], +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96317780dffb52409562c395c06797b3c658a1db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,440 @@ +# mypy: ignore-errors + +import contextlib +import functools +import logging +import os +import re +import sys +import unittest +from subprocess import CalledProcessError + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._inductor.codecache import CppCodeCache +from torch._inductor.codegen.common import ( + get_custom_backend_config_for_device, + get_custom_backend_pass_for_device, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + init_backend_registration, + register_backend_for_device, +) +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.compile_fx import shape_env_from_inputs +from torch._inductor.custom_graph_pass import CustomGraphModulePass +from torch._inductor.graph import GraphLowering +from torch._inductor.utils import ( + get_gpu_shared_memory, + get_gpu_type, + GPU_TYPES, + is_big_gpu, + is_gpu, + OrderedSet, +) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._helion import has_helion +from torch.utils._pallas import has_pallas_package, has_tpu_pallas +from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule +from torch.testing._internal.common_device_type import ( + get_desired_device_type_test_bases, +) +from torch.testing._internal.common_utils import ( + IS_CI, + IS_FBCODE, + IS_WINDOWS, + LazyVal, + TestCase, +) + +log: logging.Logger = logging.getLogger(__name__) + + +def test_cpu(): + try: + CppCodeCache.load("") + return not IS_FBCODE + except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, + ): + return False + + +HAS_CPU = LazyVal(test_cpu) + +HAS_TRITON = has_triton() + +HAS_PALLAS = has_pallas_package() + +HAS_HELION = has_helion() + +if HAS_TRITON: + import triton + + TRITON_HAS_CPU = "cpu" in triton.backends.backends +else: + TRITON_HAS_CPU = False + + +HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON + +HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON + +HAS_MPS = torch.mps.is_available() + +HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON +HAS_GPU_AND_TRITON = HAS_GPU + +GPU_TYPE = get_gpu_type() + +HAS_MULTIGPU = any( + getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2 + for gpu in GPU_TYPES +) + +_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True) +RUN_GPU = HAS_GPU and any( + is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases +) + +RUN_CPU = HAS_CPU and any( + getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases +) + +HAS_TPU = has_tpu_pallas() +RUN_TPU = HAS_TPU + + +def _check_has_dynamic_shape( + self: TestCase, + code, +): + for_loop_found = False + has_dynamic = False + lines = code.split("\n") + for line in lines: + if "for(" in line: + for_loop_found = True + if re.search(r";.*ks.*;", line) is not None: + has_dynamic = True + break + self.assertTrue( + has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}" + ) + self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}") + + +def skipDeviceIf(cond, msg, *, device): + if cond: + + def decorate_fn(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + if not hasattr(self, "device"): + warn_msg = ( + "Expect the test class to have attribute device but not found. " + ) + if hasattr(self, "device_type"): + warn_msg += "Consider using the skip device decorators in common_device_type.py" + log.warning(warn_msg) + if self.device == device: + raise unittest.SkipTest(msg) + return fn(self, *args, **kwargs) + + return inner + + else: + + def decorate_fn(fn): + return fn + + return decorate_fn + + +def skip_windows_ci(name: str, file: str) -> None: + if IS_WINDOWS and IS_CI: + module = os.path.basename(file).strip(".py") + sys.stderr.write( + f"Windows CI does not have necessary dependencies for {module} tests yet\n" + ) + if name == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + + +# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS +requires_gpu = functools.partial( + unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu" +) +requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") +requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion") + + +def requires_cuda_with_enough_memory(min_mem_required): + def inner(fn): + if ( + not torch.cuda.is_available() + or torch.cuda.get_device_properties().total_memory < min_mem_required + ): + return unittest.skip( + f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe" + )(fn) + else: + return fn + + return inner + + +skipCUDAIf = functools.partial(skipDeviceIf, device="cuda") +skipXPUIf = functools.partial(skipDeviceIf, device="xpu") +skipCPUIf = functools.partial(skipDeviceIf, device="cpu") + +IS_A100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 166912) + +IS_H100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448) + +IS_BIG_GPU = LazyVal(lambda: HAS_GPU_AND_TRITON and is_big_gpu()) + + +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + + +def maybe_skip_size_asserts(op): + """ + For certain ops, there meta and eager implementation returns different + strides. This cause size/strides assert fail. Skip adding those + asserts for now. + """ + if ( + op.aten_name + in ( + "fft_hfftn", + "fft_hfft", + "fft_hfft2", + "fft_ihfftn", + "fft_fft", + "fft_fft2", + "fft_fftn", + "fft_ifft", + "fft_ifft2", + "fft_ifftn", + "fft_irfft", + "fft_irfft2", + "fft_irfftn", + "fft_ihfft", + "fft_ihfft2", + "fft_rfft", + "fft_rfft2", + "fft_rfftn", + "linalg_eig", + "linalg_eigvals", + ) + and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ + ): + return torch._inductor.config.patch(size_asserts=False) + else: + return contextlib.nullcontext() + + +def get_func_call() -> str: + return ( + "void inductor_entry_impl(" + if torch._inductor.config.cpp_wrapper + else "def call(" + ) + + +def get_kernel_launch() -> str: + return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run(" + + +def clone_preserve_strides_offset(x, device=None): + if not isinstance(x, torch.Tensor): + return x + buffer = torch.as_strided( + x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 + ) + if not device: + buffer = buffer.clone() + else: + buffer = buffer.to(device, copy=True) + out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + return out + + +# define the e4m3/e5m2 constants +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max +E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max + +FP16_MAX_POS: float = torch.finfo(torch.float16).max +EPS: float = 1e-12 + +Tensor = torch.Tensor + + +def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor: + # The default behavior in PyTorch for casting to `float8_e4m3fn` + # and `e5m2` is to not saturate. In this context, we should saturate. + # A common case where we want to saturate is when the history of a + # tensor has a maximum value of `amax1`, and the current amax value + # is `amax2`, where `amax1 < amax2`. This is common when using delayed + # scaling. + if float8_dtype == torch.float8_e4m3fn: + x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + elif float8_dtype == torch.float8_e5m2: + x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + elif float8_dtype == torch.float8_e4m3fnuz: + x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS) + elif float8_dtype == torch.float8_e5m2fnuz: + x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS) + else: + raise TypeError(f"Unsupported float8_dtype: {float8_dtype}") + return x.to(float8_dtype) + + +@torch.no_grad() +def _amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +) -> torch.Tensor: + # To make scale dtype to be fp32 for accuracy + amax = amax.float() + if float8_dtype == torch.float8_e4m3fn: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: # e5m2 + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + + # Ensure that the scale is representable in float16, + # this helps when amax is small. We are assuming that we don't need + # to care about this for float32/bfloat16. + if orig_dtype is torch.float16: + res = torch.clamp(res, max=FP16_MAX_POS) + return res + + +def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x)) + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x), dim=1, keepdim=True).values + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +def _quantize_blockwise( + x: Tensor, float8_dtype: torch.dtype, block_outer: int, block_inner: int +): + min_outer = min(block_outer, x.shape[0]) + min_inner = min(block_inner, x.shape[1]) + x = x.unflatten(1, (-1, min_inner)).unflatten(0, (-1, min_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + scale_expanded = scale.repeat_interleave(min_outer, dim=0).repeat_interleave( + min_inner, dim=1 + ) + x_fp8 = _to_fp8_saturated( + x / scale_expanded, # Ensures that scaling doesn't cause inf/nan values + float8_dtype, + ) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +class MockGraphHandler(GraphLowering): + """Minimal mock graph handler for testing virtualized context.""" + + def __init__(self, name_to_buffer=None): + import torch._inductor.sizevars + + self.sizevars = torch._inductor.sizevars.SizeVarAllocator() + self.name_to_buffer = name_to_buffer or {} + self.graph_inputs = {} + self.mutated_buffers = OrderedSet() + self.removed_buffers = OrderedSet() + self.constants = {} + self.scheduler = None + + def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002 + """Return default dtype for any buffer (for testing).""" + return torch.float32 + + +@contextlib.contextmanager +def patch_inductor_backend( + device: str, + python_wrapper_codegen: PythonWrapperCodegen = None, + custom_pass: CustomGraphModulePass = None, + custom_backend_config: ConfigModule = None, +): + """ + Patch the inductor backend for a specific device. + """ + # Make sure the backend is already registered + init_backend_registration() + + # Get the original registration parameters + original_scheduling = get_scheduling_for_device(device) + original_python_wrapper = get_wrapper_codegen_for_device(device, False) + original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) + original_fx_wrapper = get_wrapper_codegen_for_device(device, fx_wrapper=True) + original_custom_pass = get_custom_backend_pass_for_device(device) + original_custom_backend_config = get_custom_backend_config_for_device(device) + + try: + # Register modified backend for the device + register_backend_for_device( + device, + original_scheduling, + ( + python_wrapper_codegen + if python_wrapper_codegen is not None + else original_python_wrapper + ), + original_cpp_wrapper, + original_fx_wrapper, + custom_pass if custom_pass is not None else original_custom_pass, + ( + custom_backend_config + if custom_backend_config is not None + else original_custom_backend_config + ), + ) + yield + finally: + # Restore the original backend + register_backend_for_device( + device, + original_scheduling, + original_python_wrapper, + original_cpp_wrapper, + original_fx_wrapper, + original_custom_pass, + original_custom_backend_config, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..e71f0f46854756a4b4251df6a53a03a288183172 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py @@ -0,0 +1,168 @@ +# mypy: ignore-errors + +import torch +from torch.utils._pytree import tree_map +from typing import Optional +from collections.abc import Iterator +import logging +import contextlib +import itertools +from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.weak import WeakTensorKeyDictionary +import functools +from torch._C._profiler import gather_traceback, symbolize_tracebacks + +logger = logging.getLogger("LoggingTensor") + +# How the chain of calls works for LoggingTensor: +# 1. Call torch.sin +# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely +# 3. Enter dispatcher, wind your way through Autograd +# 4. Hit Python dispatch key, call __torch_dispatch__ + +# This Tensor can work with autograd in two ways: +# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor +# can require gradients if the user asks for it as a constructor kwarg. +# - The wrapped Tensor can require gradients. In that case autograd will be tracked +# for the wrapped Tensor and the LoggingTensor itself cannot require gradients. +# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single +# test or you might get surprising behavior. + +# TODO: TensorBase should work +class LoggingTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + context = contextlib.nullcontext + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # The wrapping tensor (LoggingTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + strides=elem.stride(), storage_offset=elem.storage_offset(), + # TODO: clone storage aliasing + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=kwargs.get("requires_grad", False) + ) + # ...the real tensor is held as an element on the tensor. + r.elem = elem.detach() if r.requires_grad else elem + return r + + def __repr__(self): + return super().__repr__(tensor_contents=f"{self.elem}") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, cls) else e + + def wrap(e): + return cls(e) if isinstance(e, torch.Tensor) else e + + with cls.context(): + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + rs = func(*args, **kwargs) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorReentrant(LoggingTensor): + context = torch.overrides.enable_reentrant_dispatch + +# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list +class LoggingTensorHandler(logging.Handler): + def __init__( + self, log_list: list[str], use_shortid_for_all_tensors: bool, + with_type: bool, tracebacks_list: Optional[list]) -> None: + logging.Handler.__init__(self) + self.log_list = log_list + self.use_shortid_for_all_tensors = use_shortid_for_all_tensors + self.tracebacks_list = tracebacks_list + self.memo = WeakTensorKeyDictionary() + self.next_id = 0 + self.with_type = with_type + + def _shortid(self, t: torch.Tensor) -> int: + if t not in self.memo: + self.memo[t] = self.next_id + self.next_id += 1 + return self.memo[t] + + def _fmt(self, a: object, with_type: bool = False) -> str: + cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor + if isinstance(a, cond_cls): + maybe_type = "" + if with_type and self.with_type: + maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]" + x = f"${self._shortid(a)}{maybe_type}" + return x + else: + return repr(a) + + def emit(self, record): + fmt_args = ", ".join( + itertools.chain( + (str(tree_map(self._fmt, a)) for a in record.args[0]), + (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()), + ) + ) + fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2]) + self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})') + if self.tracebacks_list is not None: + self.tracebacks_list.append(record.traceback) + +def log_input(name: str, var: object) -> None: + logger.info("input", (name,), {}, var) # noqa: PLE1205 + +class GatherTraceback(logging.Filter): + def __init__(self, python=True, script=True, cpp=False): + self.python = python + self.script = script + self.cpp = cpp + + def filter(self, record): + record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp) + return True + +@contextlib.contextmanager +def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]: + collect_traceback = python_tb or script_tb or cpp_tb + log_list: list[str] = [] + tracebacks_list: list[str] = [] + handler = LoggingTensorHandler( + log_list, + with_type=True, + use_shortid_for_all_tensors=is_mode, + tracebacks_list=tracebacks_list if collect_traceback else None + ) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + if collect_traceback: + logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)) + try: + if collect_traceback: + yield log_list, tracebacks_list + else: + yield log_list + finally: + symbolized_tracebacks = symbolize_tracebacks(tracebacks_list) + tracebacks_list.clear() + tracebacks_list.extend(symbolized_tracebacks) + logger.removeHandler(handler) + +@contextlib.contextmanager +def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False): + with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs: + yield logs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1ecf8f4f707c9b3712a6fb738fc9ce1467b835 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors + +import torch._dynamo.test_case +import unittest.mock +import os +import contextlib +import torch._logging +import torch._logging._internal +from contextlib import AbstractContextManager +from collections.abc import Callable +from torch._dynamo.utils import LazyString +from torch._inductor import config as inductor_config +import logging +import io + +@contextlib.contextmanager +def preserve_log_state(): + prev_state = torch._logging._internal._get_log_state() + torch._logging._internal._set_log_state(torch._logging._internal.LogState()) + try: + yield + finally: + torch._logging._internal._set_log_state(prev_state) + torch._logging._internal._init_logs() + +def log_settings(settings): + exit_stack = contextlib.ExitStack() + settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings}) + exit_stack.enter_context(preserve_log_state()) + exit_stack.enter_context(settings_patch) + torch._logging._internal._init_logs() + return exit_stack + +def log_api(**kwargs): + exit_stack = contextlib.ExitStack() + exit_stack.enter_context(preserve_log_state()) + torch._logging.set_logs(**kwargs) + return exit_stack + + +def kwargs_to_settings(**kwargs): + INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"} + + settings = [] + + def append_setting(name, level): + if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY: + settings.append(INT_TO_VERBOSITY[level] + name) + return + else: + raise ValueError("Invalid value for setting") + + for name, val in kwargs.items(): + if isinstance(val, bool): + settings.append(name) + elif isinstance(val, int): + append_setting(name, val) + elif isinstance(val, dict) and name == "modules": + for module_qname, level in val.items(): + append_setting(module_qname, level) + else: + raise ValueError("Invalid value for setting") + + return ",".join(settings) + + +# Note on testing strategy: +# This class does two things: +# 1. Runs two versions of a test: +# 1a. patches the env var log settings to some specific value +# 1b. calls torch._logging.set_logs(..) +# 2. patches the emit method of each setup handler to gather records +# that are emitted to each console stream +# 3. passes a ref to the gathered records to each test case for checking +# +# The goal of this testing in general is to ensure that given some settings env var +# that the logs are setup correctly and capturing the correct records. +def make_logging_test(**kwargs): + def wrapper(fn): + @inductor_config.patch({"fx_graph_cache": False}) + def test_fn(self): + + torch._dynamo.reset() + records = [] + # run with env var + if len(kwargs) == 0: + with self._handler_watcher(records): + fn(self, records) + else: + with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records): + fn(self, records) + + # run with API + torch._dynamo.reset() + records.clear() + with log_api(**kwargs), self._handler_watcher(records): + fn(self, records) + + + return test_fn + + return wrapper + +def make_settings_test(settings): + def wrapper(fn): + def test_fn(self): + torch._dynamo.reset() + records = [] + # run with env var + with log_settings(settings), self._handler_watcher(records): + fn(self, records) + + return test_fn + + return wrapper + +class LoggingTestCase(torch._dynamo.test_case.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""}) + ) + cls._exit_stack.enter_context( + unittest.mock.patch("torch._dynamo.config.suppress_errors", True) + ) + cls._exit_stack.enter_context( + unittest.mock.patch("torch._dynamo.config.verbose", False) + ) + + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + torch._logging._internal.log_state.clear() + torch._logging._init_logs() + + def hasRecord(self, records, m): + return any(m in r.getMessage() for r in records) + + def getRecord(self, records, m): + record = None + for r in records: + # NB: not r.msg because it looks like 3.11 changed how they + # structure log records + if m in r.getMessage(): + self.assertIsNone( + record, + msg=LazyString( + lambda: f"multiple matching records: {record} and {r} among {records}" + ), + ) + record = r + if record is None: + self.fail(f"did not find record with {m} among {records}") + return record + + # This patches the emit method of each handler to gather records + # as they are emitted + def _handler_watcher(self, record_list): + exit_stack = contextlib.ExitStack() + + def emit_post_hook(record): + nonlocal record_list + record_list.append(record) + + # registered logs are the only ones with handlers, so patch those + for log_qname in torch._logging._internal.log_registry.get_log_qnames(): + logger = logging.getLogger(log_qname) + num_handlers = len(logger.handlers) + self.assertLessEqual( + num_handlers, + 2, + "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).", + ) + + self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers") + + for handler in logger.handlers: + old_emit = handler.emit + + def new_emit(record): + old_emit(record) + emit_post_hook(record) + + exit_stack.enter_context( + unittest.mock.patch.object(handler, "emit", new_emit) + ) + + return exit_stack + + +def logs_to_string(module, log_option): + """Example: + logs_to_string("torch._inductor.compile_fx", "post_grad_graphs") + returns the output of TORCH_LOGS="post_grad_graphs" from the + torch._inductor.compile_fx module. + """ + log_stream = io.StringIO() + handler = logging.StreamHandler(stream=log_stream) + + @contextlib.contextmanager + def tmp_redirect_logs(): + try: + logger = torch._logging.getArtifactLogger(module, log_option) + logger.addHandler(handler) + yield + finally: + logger.removeHandler(handler) + + def ctx_manager(): + exit_stack = log_settings(log_option) + exit_stack.enter_context(tmp_redirect_logs()) + return exit_stack + + return log_stream, ctx_manager + + +def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]: + """Example: + multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs") + returns the output of TORCH_LOGS="pre_graph_graphs, post_grad_graphs" from the + torch._inductor.compile_fx module. + """ + log_streams = [io.StringIO() for _ in range(len(log_options))] + handlers = [logging.StreamHandler(stream=log_stream) for log_stream in log_streams] + + @contextlib.contextmanager + def tmp_redirect_logs(): + loggers = [torch._logging.getArtifactLogger(module, option) for option in log_options] + try: + for logger, handler in zip(loggers, handlers, strict=True): + logger.addHandler(handler) + yield + finally: + for logger, handler in zip(loggers, handlers, strict=True): + logger.removeHandler(handler) + + def ctx_manager() -> AbstractContextManager[None]: + exit_stack = log_settings(", ".join(log_options)) + exit_stack.enter_context(tmp_redirect_logs()) + return exit_stack # type: ignore[return-value] + + return log_streams, ctx_manager diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97c38f3560625213fbd59d09a9cfd22bad26ba04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py @@ -0,0 +1,4 @@ +# mypy: ignore-errors + +import torch.testing._internal.opinfo.core +import torch.testing._internal.opinfo.definitions diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e88e239e7b6ce0567c09e8640c23a9547dd67e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py @@ -0,0 +1,3221 @@ +# mypy: ignore-errors + +import collections +import collections.abc +import contextlib +import logging +import math +import operator +import unittest +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable +from dataclasses import asdict, dataclass, field +from enum import Enum +from functools import partial +from itertools import product +from typing import Any, Optional, TypeVar, Union + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import ( + skipCPUIfNoFFT, + tol, + toleranceOverride, +) +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + get_all_dtypes, +) +from torch.testing._internal.common_utils import ( + extract_test_fn, + IS_FBCODE, + is_iterable_of_tensors, + noncontiguous_like, + OPINFO_SAMPLE_INPUT_INDEX, + TEST_WITH_ROCM, + torch_to_numpy_dtype_dict, + TrackedInputIter, + USE_PYTEST, +) +from torch.testing._internal.opinfo import utils +from torchgen.utils import dataclass_repr + + +# setup logging +log = logging.getLogger(__name__) + +# Reasonable testing sizes for dimensions +L = 20 +M = 10 +S = 5 +XS = 3 + +# Unique value to distinguish default from anything else +_NOTHING = object() + + +# Extension of getattr to support qualified names +# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm +def _getattr_qual(obj, name, default=_NOTHING): + try: + for path in name.split("."): + obj = getattr(obj, path) + return obj + except AttributeError: + if default is not _NOTHING: + return default + else: + raise + + +class DecorateInfo: + """Describes which test, or type of tests, should be wrapped in the given + decorators when testing an operator. Any test that matches all provided + arguments will be decorated. The decorators will only be applied if the + active_if argument is True.""" + + __slots__ = [ + "decorators", + "cls_name", + "test_name", + "device_type", + "dtypes", + "active_if", + ] + + def __init__( + self, + decorators, + cls_name=None, + test_name=None, + *, + device_type=None, + dtypes=None, + active_if=True, + ): + self.decorators = ( + list(decorators) + if isinstance(decorators, collections.abc.Sequence) + else [decorators] + ) + self.cls_name = cls_name + self.test_name = test_name + self.device_type = device_type + self.dtypes = dtypes + self.active_if = active_if + + # Validate dtypes + if self.dtypes is not None: + for dtype in self.dtypes: + assert isinstance(dtype, torch.dtype) + + def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs): + return ( + self.active_if + and (self.cls_name is None or self.cls_name == cls_name) + and (self.test_name is None or self.test_name == test_name) + and (self.device_type is None or self.device_type == device_type) + and (self.dtypes is None or dtype in self.dtypes) + # Support callables over kwargs to determine if the decorator is active. + and ( + self.active_if(param_kwargs) + if isinstance(self.active_if, Callable) + else self.active_if + ) + ) + + +# FIXME +# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying +# to support scalar inputs, too. Some tests still depend on 'input' being a Tensor +# or TensorList, however. +class SampleInput: + """Represents sample inputs to a function.""" + + __slots__ = [ + "input", + "args", + "kwargs", + "output_process_fn_grad", + "broadcasts_input", + "name", + ] + + def __init__( + self, + input, + *var_args, + args=None, + kwargs=None, + output_process_fn_grad=None, + broadcasts_input=None, + name=None, + **var_kwargs, + ): + # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]). + # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). + self.input = input + + # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as + # SampleInput(input, *args, **kwargs) but not to mix the two forms + if args is not None or kwargs is not None: + assert not var_args and not var_kwargs, """ +A SampleInput can be constructed "naturally" with *args and **kwargs or by +explicitly setting the "args" and "kwargs" parameters, but the two +methods of construction cannot be mixed!""" + elif var_args or var_kwargs: + assert ( + output_process_fn_grad is None + and broadcasts_input is None + and name is None + ), """ +A SampleInput constructed "naturally" with *args and **kwargs +cannot specify additional metadata in keyword arguments""" + + self.args = args if args is not None else var_args + assert isinstance(self.args, tuple) + self.kwargs = kwargs if kwargs is not None else var_kwargs + assert isinstance(self.kwargs, dict) + + self.output_process_fn_grad = ( + output_process_fn_grad + if output_process_fn_grad is not None + else lambda x: x + ) + self.name = name if name is not None else "" + + # Specifies if `self.input` is broadcasted or not, + # given that the operator supports broadcasting. + # This field is used to verify the behavior for inplace variant. + # + # If a SampleInput is marked with `broadcasts_input=True`, + # it is verified that we get a `RuntimeError` with this sample, + # and inplace variant. Also inplace grad{grad} tests are skipped, + # for such inputs (as they will error out otherwise). + self.broadcasts_input = ( + broadcasts_input if broadcasts_input is not None else False + ) + + def with_metadata( + self, *, output_process_fn_grad=None, broadcasts_input=None, name=None + ): + if output_process_fn_grad is not None: + self.output_process_fn_grad = output_process_fn_grad + if broadcasts_input is not None: + self.broadcasts_input = broadcasts_input + if name is not None: + self.name = name + return self + + def _repr_helper(self, formatter): + # Helper function to return the details of the SampleInput as `str` + # It consolidates all the fields of SampleInput and allows, + # formatting the fields like `input`, `args`, etc with `formatter` + # callable to customize the representation. + # Look at `summary` method for example. + arguments = [ + f"input={formatter(self.input)}", + f"args={formatter(self.args)}", + f"kwargs={formatter(self.kwargs)}", + f"broadcasts_input={self.broadcasts_input}", + f"name={repr(self.name)}", + ] + + return f"SampleInput({', '.join(a for a in arguments if a is not None)})" + + def __repr__(self): + return self._repr_helper(lambda x: x) + + def summary(self): + # Returns the SampleInput details in a more + # friendly format. + # It formats `Tensor` and `TensorList` + # in a more condensed representation. + def formatter(arg): + # Format any instance of `Tensor` (standalone, in list, or in dict) + # by Tensor[TensorShape] + # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] + if isinstance(arg, torch.Tensor): + shape = str(tuple(arg.shape)) + dtype = str(arg.dtype) + device = str(arg.device) + contiguity_suffix = "" + # NB: sparse CSR tensors annoyingly return is_sparse=False + is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr + if not is_sparse and not arg.is_contiguous(): + contiguity_suffix = ", contiguous=False" + return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' + elif isinstance(arg, dict): + return {k: formatter(v) for k, v in arg.items()} + elif is_iterable_of_tensors(arg): + return "TensorList[" + ", ".join(map(formatter, arg)) + "]" + elif isinstance(arg, (list, tuple)): # Handle list, tuple + return "(" + ",".join(map(formatter, arg)) + ")" + + return repr(arg) + + return self._repr_helper(formatter) + + # Applies the transform f(t) -> t to each tensor and dtype in the SampleInput + def transform(self, f): + def tt(t): + def _tt(t): + with torch.no_grad(): + return f(t) + + if isinstance(t, torch.Tensor): + return _tt(t) + elif isinstance(t, torch.dtype): + return _tt(t) + elif isinstance(t, list): + return list(map(tt, t)) + elif isinstance(t, tuple): + return tuple(map(tt, t)) + elif isinstance(t, dict): + return {k: tt(v) for k, v in t.items()} + else: + return t + + sample_tt_input, tt_args, tt_kwargs = ( + tt(self.input), + tt(self.args), + tt(self.kwargs), + ) + + # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid! + return SampleInput( + sample_tt_input, + args=tt_args, + kwargs=tt_kwargs, + output_process_fn_grad=self.output_process_fn_grad, + broadcasts_input=self.broadcasts_input, + name=self.name + "_transformed", + ) + + # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs) + # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them + # Converts dtypes by remapping them using torch_to_numpy_dtype_dict + def numpy(self): + def to_numpy(t): + if isinstance(t, torch.Tensor): + if t.dtype is torch.bfloat16: + return t.detach().cpu().to(torch.float32).numpy() + if t.dtype is torch.chalf: + return t.detach().cpu().to(torch.cfloat).numpy() + return t.detach().cpu().numpy() + elif isinstance(t, torch.dtype): + return torch_to_numpy_dtype_dict[t] + + return t + + return self.transform(to_numpy) + + def noncontiguous(self): + def to_noncontiguous(t): + if isinstance(t, torch.Tensor): + return noncontiguous_like(t) + elif isinstance(t, torch.dtype): + return t + + return t + + return self.transform(to_noncontiguous) + + +NumericsFilter = collections.namedtuple("NumericsFilter", ["condition", "safe_val"]) + + +class ErrorInput: + """ + A SampleInput that will cause the operation to throw an error plus information + about the resulting error. + """ + + __slots__ = ["sample_input", "error_type", "error_regex"] + + def __init__(self, sample_input, *, error_type=RuntimeError, error_regex): + self.sample_input = sample_input + self.error_type = error_type + self.error_regex = error_regex + + +class AliasInfo: + """Class holds alias information. For example, torch.abs -> + torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ + """ + + def __init__(self, alias_name): + self.name = alias_name + self.op = _getattr_qual(torch, alias_name) + self.method_variant = getattr(torch.Tensor, alias_name, None) + self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None) + + def __call__(self, *args, **kwargs): + return self.op(*args, **kwargs) + + +# Note [OpInfos] +# ~~~~~~~~~~~~~~ +# +# The majority of this note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261) +# See also: "Writing Test Templates" in common_device_type.py to learn how to +# parametrize a test template using OpInfos. +# See also: PyTorch's GitHub wiki on running and writing tests +# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py +# +# An OpInfo is a collection of metadata related to a PyTorch operator. This +# metadata is used to generate tests that validate properties of the operator, +# like if it implements the correct gradient formula. +# +# WHY OPINFOS? +# ~~~~~~~~~~~~ +# +# OpInfos are principally intended to do three things: +# +# 1) to allow systematic testing over all PyTorch's operators +# 2) to simplify operating testing by autogenerating many tests +# 3) to allow systems (like autograd, torchscript, fx, nnc...) to test +# against every PyTorch operator +# +# All these goals are still a work in progress. Not every operator has an +# OpInfo, and some operator tests that could be automatically generated +# still have to be written manually. +# +# It's helpful to understand that OpInfos are both about test simplification and +# modularity. PyTorch is a complicated framework with many interrelated systems, +# too many for any one person to keep track of. An OpInfo can be thought of as the +# interface between an operator implementer and those other systems. Instead of +# requiring the implementer of torch.foo understand how to test its forward +# mode AD or NNC support that's typically handled automatically just by +# defining an OpInfo. +# +# It's often surprising to OpInfo writers that just implementing an OpInfo +# typically can't verify an operator is actually implemented correctly: +# +# "If an OpInfo doesn't validate my op works as expected, what's the point +# of it?" +# +# But the point of is the above. OpInfos are intended to let you focus on testing +# the operator logic you're familiar with instead of having to write tests for +# how the operator interacts with each of PyTorch's many systems. +# +# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES +# validate your op works as expected, but that's only in special +# cases. See below for details. +# +# WHAT'S AN OPINFO? +# ~~~~~~~~~~~~~~~~~ +# +# So what is an OpInfo? It's a Python class that describes an operator's properties, +# like which dtypes it supports on the CPU and whether it has any aliases. +# These properties can be divided into three categories: +# +# 1) Metadata describing the operator, like the operator's name and if it +# "supports" the out kwarg. +# 2) Test directives, like "skips" that tell the test suite to skip some +# tests. +# 3) A "sample inputs" function that generates valid inputs for the operator. +# +# OpInfo attributes are described in more detail below. +# +# THE SAMPLE INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The "sample inputs" function merits special elaboration. This function is +# crucial to testing with OpInfos. A typical OpInfo test has to treat the operator +# as a black box. There's no structure for the test to understand or exploit. +# Without "sample inputs" it wouldn't even know how to call the OpInfo's +# operator. The sample input function saves the day by providing different +# "SampleInputs" that can be used to call the operator. A sample input +# function should have the following signature: +# +# def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs): +# +# And should return an iterable of SampleInputs (see the class description +# above). Each SampleInput defines an "input", "args", "kwargs", an +# "output_process_fn_grad" function, the "broadcasts_input" bool and a +# "name". +# +# All the "sample_inputs" functions are invoked within a `torch.no_grad()` +# environment for efficiency and correctness. As such remember to set the +# "requires_grad" flag on the inputs **after** performing any transformations +# on them. +# +# The "input" is the first argument to the operator, or the tensor that +# the method or inplace variants of the operator should be called on, and +# should be on the requested device, of the requested dtype, and its +# requires_grad attribute should be set to the requires_grad argument. +# +# "args" should contain positional arguments, and "kwargs" keyword arguments. +# +# "output_process_fn_grad" has an interesting name. It's a function that maps +# the operator's output (when given the input, args, and kwargs) to the +# portion of the output to gradcheck. For example, consider an operator +# like torch.linalg.slogdet +# (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html). +# This operator returns a tuple of two tensors, but the first tensor +# cannot be backwarded through. Its "output_process_fn_grad" filters +# this output tuple to just the second argument, which we can call backward +# on. Functions that produce a single tensor can ignore this argument. +# +# "broadcasts_input" is a bool indicated if the SampleInput causes the operator +# to broadcast the "input" argument. This is important for tests to understand +# because inplace variants of operations throw a runtime error if they +# would broadcast their input arguments, so tests that work with inplace +# variants filter SampleInputs that broadcast their input. +# +# "name" is a string that's just used for debugging. It appears when printing +# the SampleInput. +# +# Sample inputs are designed to be used with many tests, some +# that are very time consuming, so they should be a small +# set with small tensors. An elaborated set of sample inputs +# can be specified using the "reference_inputs_func" attribute. +# The "reference inputs" for an operation are an extended +# set of sample inputs that can more exhaustively test an +# operator. They are used by only a few tests that are careful +# not to take too long to run. Adding reference inputs +# is highly encouraged! +# +# THE (OPTIONAL) ERROR INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# OpInfos may optionally specify "error inputs" through an error function. If +# specified test_errors in test_ops.py will call the op with these inputs +# and validate that the desired error is thrown. +# +# Error inputs automate a common testing pattern where multiple inputs are +# passed to an operation and the errors they thrown are reviewed. Tests +# written in this style should be ported to the new OpInfo pattern. +# +# Error inputs are specified using the ErrorInputs class, which contains +# a SampleInput (see above) and data about the expected error. +# +# OPINFO FILE ORGANIZATION +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# All OpInfos are currently defined in this file. Most OpInfo tests are defined +# in test_ops.py, but some system-specific tests are defined in those +# systems' test files, and subclass-specific tests are defined in the test +# file that corresponds to that subclass (see the below). +# Expect a reorganization in the future. +# +# WHAT'S TESTED? +# ~~~~~~~~~~~~~~ +# +# Every OpInfo in the op_db sequence has the following properties validated in +# test_ops.py: +# +# - that its supported dtypes are specified correctly +# - that the operation produces the same results when called with noncontiguous inputs +# - that it supports the out= argument properly (if it allows out=), +# see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch +# - that it works with the conjugate view bit properly +# - that its function, method, and inplace variants perform the same operation +# (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all +# do the same thing). +# - that its inplace variant preserves the input's storage +# - that its gradient formula is implemented correctly, and that it supports +# gradgrad and complex grad and gradgrad and forward mode AD properly for +# the op's function and inplace variants (method variants are skipped +# to reduce test time). +# - that the operation performs the same operation when traced or scripted +# using the jit +# - that the operation is autodifferentiated by the jit as expected +# - that the operator's aliases, if any, perform the same operation and that +# the jit understands the alias +# - that the operator throws the correct errors (if error_inputs is defined) +# - that the operator produces the same results as a NumPy reference (if ref is defined) +# - that the operator produces the same results as a NumPy reference on an extended +# set of "reference inputs" (if both ref and reference_inputs_func are defined) +# (NOTE: elementwise unary and elementwise binary OpInfos do this even if only +# ref is defined, because they effectively autogenerate reference inputs) +# - that the operator works on different CUDA devices +# +# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py, +# and test_fx.py. These tests validate that operators work with NNC and FX +# as expected. +# +# For performance, some of the above tests may only run on the first +# SampleInput returned by an OpInfo's sample input function. +# +# In addition to these tests, some subclasses (discussed in the next section) +# define additional tests. +# +# Critically, as mentioned above, what's not necessarily tested is that the operator +# works as expected. When implementing an OpInfo an engineer must still +# typically write one or more tests validating the operator's behavior. +# The exception to this is if reference testing is sufficient, or if +# the operation belongs to an OpInfo subclass that has more exhaustive +# operator testing. Elementwise unary and elementwise binary operators, +# in particular, usually don't require additional testing beyond +# writing an Opinfo. +# +# +# OPINFO (SUB)CLASSES +# ~~~~~~~~~~~~~~~~~~~ +# +# In addition to the OpInfo base class there are several specialized OpInfo +# subclasses. For example, the UnaryUfuncInfo subclass is used for +# unary elementwise operations. These operations have a common structure +# that test_unary_ufuncs.py exploits with additional automated testing. +# The automated testing in test_unary_ufuncs.py is so thorough, comparing +# the operator to a NumPy reference function on a plethora of values, that +# just implementing an OpInfo for a unary elementwise operation is often +# sufficient testing. +# +# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a +# very unique class of operations. These OpInfos aren't included in the +# op_db sequence and have their own tests. +# +# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience +# when writing OpInfos. +# +# TESTING A NEW OPERATOR +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# If you're adding a new operator to any of the following namespaces: +# - torch +# - torch.fft +# - torch.linalg, +# - torch.special +# - torch.nn.functional +# then you should typically add an OpInfo for it. +# +# As mentioned a couple times above, implementing an OpInfo is not +# usually sufficient testing (unless the operator is a unary or binary elementwise +# operator). The OpInfo will only test the properties described in the +# "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is +# implemented correctly. +# +# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to +# be consumed by a variety of systems it can be hard to understand how to +# deal with test failures or how to set the OpInfo metadata properly. +# +# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs +# function must be defined, and the operator's dtypes must be specified. +# Once that's done you should run the operator's tests in test_ops.py +# (these can be filtered using the "-k" argument in pytest). Tests that +# fail should provide an error message that describes what to change about +# your OpInfo. You don't need to worry about changing an OpInfo's default +# values unless a test yells at you. +# +# Similarly, if you're writing a test that consumes OpInfos then it's critical +# your test provides a clear error message describing what to do when it +# fails. You should not assume the OpInfo implementer is familiar with your +# system. +# +# If you see a confusing error message while developing an OpInfo then please +# file an issue describing what happened. +# +# This trial-and-error approach to writing an OpInfo can be frustrating, +# but it's probably necessary as long as OpInfos don't require +# learning about all the systems that consume them. One thing that can help +# is the get_supported_dtypes() function defined in utils.py. This +# function can be used to programmatically specify the dtypes an operator +# supports, and is especially useful if writing an OpInfo on a machine +# without a CUDA device. See its documentation for more details. +# +# THE FUTURE OF OPINFOS AND OPINFO TESTING +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the future we expect OpInfo coverage to improve and cover +# the great majority of PyTorch's (public) operators. +# + + +# Classes and methods for the operator database +@dataclass +class OpInfo: + """Operator information and helper functions for acquiring it.""" + + # the string name of the function + name: str + + # An optional reference function that accepts ndarrays (AKA "NumPy arrays"). + # If given, the op will be compared with its reference on each of its sample inputs. + ref: Optional[Callable] = None + + # the following metadata describes the operator, its variants, and its aliases, if any + + # iterable of aliases, e.g. ("absolute",) for torch.abs + aliases: Iterable = None + + # additional string to include in the test name + # this is useful when an op needs multiple OpInfos, + # like divide does, often because it's really several + # different ops behind the scenes + variant_test_name: str = "" + + # the function variant of the operation, populated as torch. if None + op: Callable = None + + # allows the method variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated method + # - if a Callable, then that callable should be the method associated with this operation + method_variant: Callable = _NOTHING + + # allows the inplace variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace variant + # - if a Callable, then that callable should be the inplace variant associated with this operation + inplace_variant: Callable = _NOTHING + + # allows the operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated operator + # - if a Callable, then that callable should be the operator associated with this operation + operator_variant: Callable = _NOTHING + + # allows the inplace operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace operator + # - if a Callable, then that callable should be the inplace operator associated with this operation + inplace_operator_variant: Callable = _NOTHING + + # the following metadata are test directives for skipping or modifying tests + + # information about which tests to skip + skips: tuple = () + + # decorators to apply to generated tests + decorators: tuple = () + + # the following are pointers to functions to generate certain classes of inputs + + # function to generate sample inputs with strided layouts + sample_inputs_func: Callable = None + + # function to generate a more thorough set of samples inputs with strided layouts + reference_inputs_func: Callable = None + + # function to generate inputs that will throw errors + error_inputs_func: Callable = None + + # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors + error_inputs_sparse_func: Callable = None + + # function to generate sample inputs with sparse coo layouts + sample_inputs_sparse_coo_func: Callable = None + + # function to generate sample inputs with sparse csr layouts + sample_inputs_sparse_csr_func: Callable = None + + # function to generate sample inputs with sparse csc layouts + sample_inputs_sparse_csc_func: Callable = None + + # function to generate sample inputs with sparse bsr layouts + sample_inputs_sparse_bsr_func: Callable = None + + # function to generate sample inputs with sparse bsc layouts + sample_inputs_sparse_bsc_func: Callable = None + + # the following metadata relates to dtype support and is tested for correctness in test_ops.py + + # dtypes this function works with on the CPU, + # inherited by other device types that don't specify their own dtypes + dtypes: _dispatch_dtypes = None + + # the following dtypesIf... options override the dtypes value on their respective device types + # I.e. instead of writing multiple `dtypesIfCUDA`, `dtypesIfROCM`, etc one can simply define a dict + # dtypesIf = { 'cuda': (torch.float, torch.double), 'rocm': (torch.half, torch.bfloat16) } + dtypesIf: dict[str, _dispatch_dtypes] = field(default_factory=dict) + + def __getattribute__(self, name: str) -> Any: + if name.startswith("dtypesIf") and name != "dtypesIf": + # TODO: Warn if used + dev_name = name.removeprefix("dtypesIf").lower() + return self.dtypesIf.get(dev_name) + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + # TODO: After migration, start adding warnings here + if name.startswith("dtypesIf") and name != "dtypesIf": + assert isinstance(value, (_dispatch_dtypes, type(None))) + dev_name = name.removeprefix("dtypesIf").lower() + self.dtypesIf[dev_name] = value + return + super().__setattr__(name, value) + + # dtypes this function is expected to work with on CUDA + dtypesIfCUDA: _dispatch_dtypes = None + + # dtypes this function is expected to work with on ROCM + dtypesIfROCM: _dispatch_dtypes = None + + dtypesIfHpu: _dispatch_dtypes = None + + # dtypes this function is expected to work with on XPU + dtypesIfXPU: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with + backward_dtypes: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on CUDA + backward_dtypesIfCUDA: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on ROCM + backward_dtypesIfROCM: _dispatch_dtypes = None + + backward_dtypesIfHpu: _dispatch_dtypes = None + + # the following metadata describes the operators out= support + + # whether the op supports the out kwarg + # defaults to True, if the op does not allow the out kwarg or + # supports it incorrectly then test_out in test_ops.py should fail + supports_out: bool = True + + # the following metadata relates to autograd support + # whether the operation supports backward mode AD + # if true, gradient correctness is tested in test_ops.py + # using the op's sample inputs + supports_autograd: bool = True + + # whether the op supports second order gradients + # if true, gradgrad correctness is tested in test_ops.py + # defaults to support_autograd's value + # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below + supports_gradgrad: bool = None + + # whether the ops supports second order gradients via + # forward-over-reverse. If True, forward-over-reverse gradgrad correctness + # is tested. If False, test that forward grad is not implemented. + # Defaults to False. + supports_fwgrad_bwgrad: bool = False + + # whether the operation supports inplace autograd + # if true, tested in test_ops.py + # defaults to supports_autograd's value + supports_inplace_autograd: bool = None + + # Whether the operation support forward mode AD + # If the value is True, we check that the gradients are correct + # If the value is False, we test that forward grad is not implemented + supports_forward_ad: bool = False + + # Whether the operation has a varargs variant + # (e.g. functions like ones, zeros, methods like view, permute) + supports_varargs: bool = False + + # Whether the forward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_forward: bool = True + + # Whether the backward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_backward: bool = True + + # Whether to skip the backward part of the COW tensor input test + skip_cow_input_backward: bool = False + + # If `supports_cow_input_no_materialize_forward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_forward: list[Union[int, str]] = None + + # If `supports_cow_input_no_materialize_backward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_backward: list[Union[int, str]] = None + + # wrapper function for gradcheck + gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs) + + # whether to check batched grad when doing gradcheck + # defaults to support_autograd's value + check_batched_grad: bool = None + + # whether to check batched grad grad when doing gradgradcheck + # default's to support_gradgrad's value + check_batched_gradgrad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `supports_forward_ad` + check_batched_forward_grad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `check_batched_forward_grad` + check_inplace_batched_forward_grad: bool = None + + # tolerance for nondeterminism while performing gradcheck + gradcheck_nondet_tol: float = 0.0 + + # Whether to use the fast implementation for gradcheck/gradgradcheck. + # When set to None, defers to the default value provided by the wrapper + # function around gradcheck (testing._internal.common_utils.gradcheck) + gradcheck_fast_mode: bool = None + + # the following metadata relates to JIT support and is tested for correctness in test_ops.py + + # name of the corresponding aten:: operator + aten_name: str = None + + # if this is a composite implicit autograd op, the decomposed op + decomp_aten_name: Optional[str] = None + + # name of the corresponding aten:: operator for backwards + aten_backward_name: Optional[str] = None + + # if a op's aten::node is expected to be symbolically autodiffed + assert_autodiffed: bool = False + + # a list of strings with node names that are expected to be in a + # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'], + # default is populated to be ['aten::(name of Python operator)'] + autodiff_nonfusible_nodes: list[str] = None + + # a list of strings with node names that are expected to be in FusionGroups + # inside of DifferentiableGraphs when this operation is autodiffed. + # Ex: ['aten::add', 'aten::mm'], defaults to an empty list + # Note: currently no ops use fusible nodes + autodiff_fusible_nodes: list[str] = None + + # the following metadata relates to sparse support and is used in test_sparse.py + + # whether the op supports sparse coo inputs, defaults to False + # TODO: rename supports_sparse to supports_sparse_coo + supports_sparse: bool = None + + # only run tracing tests + supports_scripting: bool = True + + # if the operator can be traced + supports_tracing: bool = True + + # the following metadata relates to sparse compressed support and + # is used in test_sparse_csr.py and test_sparse.py + + # whether the op supports sparse csr inputs, defaults to False + supports_sparse_csr: bool = None + # whether the op supports sparse csc inputs, defaults to False + supports_sparse_csc: bool = None + # whether the op supports sparse bsr inputs, defaults to False + supports_sparse_bsr: bool = None + # whether the op supports sparse bsc inputs, defaults to False + supports_sparse_bsc: bool = None + # whether the op supports nested jagged inputs, defaults to False + supports_njt: bool = None + + # whether the op promotes integer inputs to float + promotes_int_to_float: bool = False + + # the following metadata relates to complex support and is checked in test_ops.py + + test_conjugated_samples: bool = True + + test_neg_view: bool = True + + # assert that jit shape analysis fully propagates shape + assert_jit_shape_analysis: bool = False + + # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py + + supports_expanded_weight: bool = False + + is_factory_function: bool = False + + skip_correctness_check_compile_vs_eager: bool = False + + def __post_init__(self): + self._original_opinfo_args = asdict(self).copy() + + assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!" + + # Validates the dtypes are generated from the dispatch-related functions + for name, val in self.dtypesIf.items(): + if val is not None: + assert isinstance(val, _dispatch_dtypes) + self.dtypesIf[name] = set(val) + + if self.aten_name is None: + self.aten_name = self.name + + # Attribute to verify dynamic_dtypes are used. + self.dynamic_dtypes = any( + isinstance(dtypes, utils._dynamic_dispatch_dtypes) + for dtypes in self.dtypesIf.values() + ) + + if self.dynamic_dtypes: + # Make sure `dtyesIfCUDA` is dynamic, if dynamic dispatch is used for CPU + # This is because, below we set dtypesIfCUDA to dtypes if they are None. + assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), ( + f"To use dynamic dtypes for operator {self.name}, " + "acquire the dtypes dynamically for argument `dtypesIfCUDA`." + "This is to ensure that CUDA dtypes are acquired correctly as they" + "differ from CPU dtypes occasionally" + ) + + self.dtypes = set(self.dtypes) + + # NOTE: backward dtypes must be acquired before forward dtypes + # since they fallback to explicit (not implicit!) specifications of + # forward dtypes + self.backward_dtypesIfROCM = ( + set(self.backward_dtypesIfROCM) + if self.backward_dtypesIfROCM is not None + else ( + self.backward_dtypesIfCUDA + if self.backward_dtypesIfCUDA is not None + else self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfROCM + if self.dtypesIfROCM is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfCUDA = ( + set(self.backward_dtypesIfCUDA) + if self.backward_dtypesIfCUDA is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfHpu = ( + set(self.backward_dtypesIfHpu) + if self.backward_dtypesIfHpu is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypes + ) + ) + + self.backward_dtypes = ( + set(self.backward_dtypes) + if self.backward_dtypes is not None + else self.dtypes + ) + + # Inherit from cpu + for dev_type in ["cuda", "hpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypes + + # Inherit from CUDA + for dev_type in ["rocm", "xpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypesIf["cuda"] + + # NOTE: if the op is unspecified it is assumed to be under the torch namespace + if not self.op: + self.op = _getattr_qual(torch, self.name) + + if self.method_variant is _NOTHING: + self.method_variant = getattr(torch.Tensor, self.name, None) + + # attributes like real, imag are not callable + if not callable(self.method_variant): + self.method_variant = None + + if self.inplace_variant is _NOTHING: + inplace_name = self.name + "_" + self.inplace_variant = getattr(torch.Tensor, inplace_name, None) + + if self.operator_variant is _NOTHING: + self.operator_variant = getattr(operator, self.name, None) + + if self.inplace_operator_variant is _NOTHING: + # Note: operator.i will use operator. and assign the result to the lhs when no + # __i__ method is found. This results in the appearance of an inplace operator variant which + # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace + # operator with a check that an inplace variant exists. + if self.inplace_variant is not None: + inplace_operator_name = "i" + self.name + self.inplace_operator_variant = getattr( + operator, inplace_operator_name, None + ) + else: + self.inplace_operator_variant = None + + self.decorators = (*self.decorators, *self.skips) + + # Specifying sample inputs function without specifying the + # corresponding layout support implies the layout support: + if self.supports_sparse is None: + self.supports_sparse = self.sample_inputs_sparse_coo_func is not None + if self.sample_inputs_sparse_coo_func is None: + self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified + + if self.supports_sparse_csr is None: + self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None + if self.sample_inputs_sparse_csr_func is None: + self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified + + if self.supports_sparse_csc is None: + self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None + if self.sample_inputs_sparse_csc_func is None: + self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsr is None: + self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None + if self.sample_inputs_sparse_bsr_func is None: + self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsc is None: + self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None + if self.sample_inputs_sparse_bsc_func is None: + self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified + + if self.supports_njt is None: + self.supports_njt = False + + # We run the sampling functions without tracking the gradiends of the creation of inputs + self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func) + self.sample_inputs_sparse_coo_func = torch.no_grad()( + self.sample_inputs_sparse_coo_func + ) + self.sample_inputs_sparse_csr_func = torch.no_grad()( + self.sample_inputs_sparse_csr_func + ) + self.sample_inputs_sparse_csc_func = torch.no_grad()( + self.sample_inputs_sparse_csc_func + ) + self.sample_inputs_sparse_bsr_func = torch.no_grad()( + self.sample_inputs_sparse_bsr_func + ) + self.sample_inputs_sparse_bsc_func = torch.no_grad()( + self.sample_inputs_sparse_bsc_func + ) + if self.reference_inputs_func is not None: + self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func) + + if not self.autodiff_fusible_nodes: + self.autodiff_fusible_nodes = [] + + if self.autodiff_nonfusible_nodes is None: + self.autodiff_nonfusible_nodes = ["aten::" + self.name] + + # Autograd support + + # Autograd flags that depend on backward AD only + # - If setting has been explicitly set, raise error if inconsistent + if self.supports_gradgrad is None: + self.supports_gradgrad = self.supports_autograd + else: + assert not (self.supports_gradgrad and not self.supports_autograd), ( + "supports_gradgrad refines the part of autograd is supported, so it should " + "not be set if supports_autograd is False" + ) + if self.check_batched_grad is None: + self.check_batched_grad = self.supports_autograd or self.supports_forward_ad + else: + assert not ( + self.check_batched_grad + and not (self.supports_autograd or self.supports_forward_ad) + ), ( + "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so " + "it should not be set if supports_autograd is False" + ) + if self.check_batched_gradgrad is None: + self.check_batched_gradgrad = self.supports_gradgrad + else: + assert not (self.check_batched_gradgrad and not self.supports_gradgrad), ( + "check_batched_gradgrad refines the part of autograd that will be checked (by " + "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd " + "is False." + ) + if self.check_batched_forward_grad is None: + self.check_batched_forward_grad = self.supports_forward_ad + else: + assert not ( + self.check_batched_forward_grad and not self.supports_forward_ad + ), ( + "check_batched_forward_grad should only be used when supports_forward_ad " + "is True. It is used to disable the test in the specific cases " + "where the op supports forward ad but fails to compute " + "batched forward grad." + ) + + if self.check_inplace_batched_forward_grad is None: + self.check_inplace_batched_forward_grad = self.check_batched_forward_grad + else: + assert not ( + self.check_inplace_batched_forward_grad + and not self.check_batched_forward_grad + ), ( + "check_batched_forward_grad should only be used when check_batched_forward_grad " + "is True. It is used to disable the test in the specific cases " + "where the op supports batched forward grad but fails to compute batched forward " + "grad for the inplace variant of the op." + ) + + assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), ( + "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be " + "True if backward ad is also checked, i.e., supports_forward_ad should be True.", + self.name, + ) + + # Autograd flags that depend on both forward AD and backward AD + if self.supports_inplace_autograd is None: + self.supports_inplace_autograd = ( + self.supports_autograd or self.supports_forward_ad + ) + else: + assert not ( + self.supports_inplace_autograd + and not self.supports_autograd + and not self.supports_forward_ad + ), ( + "supports_inplace_autograd refines the part of autograd that is supported, so " + "it should not be set if both supports_autograd and supports_forward_ad are False" + ) + + if self.aliases is not None: + self.aliases = tuple(AliasInfo(a) for a in self.aliases) # type: ignore[assignment] + else: + self.aliases = () + + def __call__(self, *args, **kwargs): + """Calls the function variant of the operator.""" + return self.op(*args, **kwargs) + + def __str__(self): + return dataclass_repr(self) + + def get_op(self): + """Returns the function variant of the operator, torch..""" + return self.op + + def get_method(self): + """Returns the method variant of the operator, torch.Tensor.. + Returns None if the operator has no method variant. + """ + return self.method_variant + + def get_inplace(self): + """Returns the inplace variant of the operator, torch.Tensor._. + Returns None if the operator has no inplace variant. + """ + return self.inplace_variant + + def get_operator(self): + """Returns operator variant of the operator, e.g. operator.neg + Returns None if the operator has no operator variant. + """ + return self.operator_variant + + def get_inplace_operator(self): + """Returns the inplace operator variant of the operator, e.g operator.iadd + Returns None if the operator has no inplace operator variant""" + return self.inplace_operator_variant + + # Returns a tuple of callables: + # (TestCase -> subtest context, TestCase -> skip / xfail context) + # I'd love to combine these into one but I haven't figured out how to do it + # in a way that works like it should, and I tried a LOT of things. + def _maybe_skip_or_xfail(self, rules, device, sample, idx): + def _subtest_fn(test_case, sample=sample.name, idx=idx): + return test_case.subTest(sample=sample, idx=idx) + + if rules is None or len(rules) == 0: + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + # NB: match first rule only (order matters!) + for rule in rules: + if rule.sample_match_fn(device, sample): + log.debug( + "matched %s rule '%s': %s %s %s", + rule.type, + rule.name, + self.full_name, + device, + sample, + ) + + # Provide a context for the test case to run the sample input + # through as a subtest AND handle skip / xfail for it as needed. + return ( + _subtest_fn, + lambda test_case, rule=rule: rule.get_context(test_case), + ) + + log.debug("matched no rules: %s %s %s", self.full_name, device, sample) + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + def _sample_callback_fn(self, use_subtests, device): + # Get sample-specific skips / xfails. + sample_skips_and_xfails = getattr( + extract_test_fn(), "sample_skips_and_xfails", None + ) + + if sample_skips_and_xfails is not None and not use_subtests: + raise RuntimeError( + """Sample-specific skips / xfails require use_subtests=True. +Please pass this to the sample generation function and run the test logic within the +returned contexts (NB: order matters!). For example: + +def test_foo(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(..., use_subtests=True): + # these contexts handle running within subtests and skips / xfails + with subtest_ctx(self), skip_xfail_ctx(self): + # test logic here + ...""" + ) + + if not use_subtests: + # use the default callback that returns the sample without a subtest context + return None + + if USE_PYTEST: + try: + import pytest_subtests # noqa: F401 + except ModuleNotFoundError: + raise RuntimeError( + "Encountered an OpInfo test with use_subtests=True and pytest-subtests is " + "not installed. The feature will not work correctly within pytest without " + "this package; please install it." + ) from None + + def _f( + sample, + idx, + self=self, + device=device, + sample_skips_and_xfails=sample_skips_and_xfails, + use_subtests=use_subtests, + ): + # When subtests are enabled, also return a subtest context. This is required + # for xfails / skips to work properly. + return ( + sample, + *self._maybe_skip_or_xfail( + sample_skips_and_xfails, device, sample, idx + ), + ) + + return _f + + def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs but with the tensor input or first + tensor in a sequence input conjugated. + """ + + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + conj_samples = list(samples) + + def conjugate(tensor): + _requires_grad = tensor.requires_grad + tensor = tensor.conj() + return tensor.requires_grad_(_requires_grad) + + for i, sample in enumerate(samples): + sample = conj_samples[i] + # Note: it is assumed that the input here is either a tensor or tensorlist + if isinstance(sample.input, torch.Tensor): + sample.input = conjugate(sample.input) + else: + sample.input[0] = conjugate(sample.input[0]) + + return TrackedInputIter( + iter(conj_samples), + "conjugate sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + These samples should be sufficient to test the function works correctly + with autograd, TorchScript, etc. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + + if kwargs.get("include_conjugated_inputs", False): + conj_samples = self.conjugate_sample_inputs( + device, dtype, requires_grad, **kwargs + ) + samples_list = list(samples) + samples_list.extend(conj_samples) + samples = tuple(samples_list) + + return TrackedInputIter( + iter(samples), + "sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + Distinct from sample_inputs() above because this returns an expanded set + of inputs when reference_inputs_func is defined. If undefined this returns + the sample inputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + if self.reference_inputs_func is None: + samples = self.sample_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(samples), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + if kwargs.get("include_conjugated_inputs", False): + raise NotImplementedError + + references = self.reference_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(references), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs(self, device, **kwargs): + """ + Returns an iterable of ErrorInputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + errs = self.error_inputs_func(self, device, **kwargs) + + def _error_item_callback(e, i, use_subtests=use_subtests, device=device): + cb = self._sample_callback_fn(use_subtests, device) + # no rules to apply; just return the sample + if cb is None: + return e + + # adapt the callback call since ErrorInputs contain SampleInputs + _, subtest_ctx = cb(e.sample_input, i) + return (e, subtest_ctx) + + return TrackedInputIter( + iter(errs), + "error input", + track_callback=lambda e: e.sample_input, + item_callback=_error_item_callback, + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs_sparse(self, device, layout, **kwargs): + """ + Returns an iterable of ErrorInputs that contain sparse sample + inputs with a specified layout. + """ + if not self.supports_sparse_layout(layout): + raise unittest.SkipTest("unsupported sparse layout") + return self.error_inputs_sparse_func(self, device, layout, **kwargs) + + def supports_sparse_layout(self, layout): + """Return True if OpInfo supports the specified sparse layout.""" + layout_name = str(layout).split(".")[-1] + # map torch.sparse_coo to OpInfo.supports_sparse: + layout_name = layout_name.replace("_coo", "") + return getattr(self, f"supports_{layout_name}") + + def sample_inputs_sparse( + self, layout, device, dtype, requires_grad=False, **kwargs + ): + """Returns an iterable of SampleInputs that contain inputs with a + specified sparse layout. + """ + layout_name = str(layout).split(".")[-1] + sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name) + + def non_empty_sampler(op, generator): + found_sample = False + for sample in generator: + found_sample = True + yield sample + if not found_sample: + raise unittest.SkipTest("NO SAMPLES!") + + return non_empty_sampler( + self, + sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs), + ) + + def _sample_inputs_unspecified(self, *args, **kwargs): + """Raises an NotImplemented exception in a OpInfo instance creation + that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True + without specifying the corresponding sample function as + sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func. + + To avoid this, either define the corresponding sample function, + or re-map unsupported samples to error inputs in an appropriate + + opinfo/definitions/sparse.py:_validate_sample_input_sparse_ + + function. + """ + raise NotImplementedError("no sample function specified") + + def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + coo layout. + """ + return self.sample_inputs_sparse_coo_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csr layout. + """ + return self.sample_inputs_sparse_csr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csc layout. + """ + return self.sample_inputs_sparse_csc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsr layout. + """ + return self.sample_inputs_sparse_bsr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsc layout. + """ + return self.sample_inputs_sparse_bsc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): + """Returns the decorators targeting the given test.""" + result = [] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active( + test_class, test_name, device, dtype, param_kwargs + ): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result + + def supported_dtypes(self, device_type): + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + if device_type == "cuda" and TEST_WITH_ROCM: + device_type = "rocm" + result = self.dtypesIf.get(device_type, self.dtypes) + if device_type == "mps": + return result - {torch.float64, torch.cdouble} + return result + + def supported_backward_dtypes(self, device_type): + if not self.supports_autograd: + return set() + + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + backward_dtypes = None + if device_type == "cuda": + backward_dtypes = ( + self.backward_dtypesIfROCM + if TEST_WITH_ROCM + else self.backward_dtypesIfCUDA + ) + elif device_type == "hpu": + backward_dtypes = self.backward_dtypesIfHpu + elif device_type == "mps": + backward_dtypes = self.backward_dtypes - {torch.double, torch.cdouble} + else: + backward_dtypes = self.backward_dtypes + + allowed_backward_dtypes = floating_and_complex_types_and( + torch.bfloat16, torch.float16, torch.complex32 + ) + return set(allowed_backward_dtypes).intersection(backward_dtypes) + + def supports_dtype(self, dtype, device_type) -> bool: + return dtype in self.supported_dtypes(device_type) + + @property + def full_name(self): + """Returns a full name that helps to uniquely identify this OpInfo.""" + variant = "." + self.variant_test_name if self.variant_test_name else "" + # example: "normal.in_place" where "normal" is the name and "in_place" is the variant + return f"{self.name}{variant}" + + @property + def formatted_name(self): + """Returns a formatted full name for this OpInfo that can be used in test names.""" + return self.full_name.replace(".", "_") + + +# Represents a skip / xfail rule matching a particular set of tests. It allows granularity +# at the device, dtype, op, and individual sample levels. This flexibility allows entire +# bugs to be represented by a single rule, even if this corresponds with multiple conceptual +# test cases across multiple ops. +@dataclass +class SampleRule(ABC): + # function to indicate whether the rule applies to this op; return True if so + # NB: str arg of callable is device_type + op_match_fn: Callable[[str, OpInfo], bool] = None + # function to indicate whether the rule applies to this sample; return True if so + sample_match_fn: Callable[[torch.device, SampleInput], bool] = None + # optional name for identifying the rule + name: str = "" + + def __post_init__(self): + if self.op_match_fn is None: + raise ValueError("must have op_match_fn set to be useful") + if self.sample_match_fn is None: + # by default, match for all samples + self.sample_match_fn = lambda device, sample: True + + # returns a string identifier of the rule type + @abstractmethod + def type(self) -> str: ... + + # returns an appropriate context that handles the xfail, skips, etc. + @abstractmethod + def get_context(self, test_case): ... + + +# useful for specifying xfails +@dataclass +class XFailRule(SampleRule): + # expected error type + error_type: TypeVar = Exception + # expected error message + error_msg: str = ".*" + + @property + def type(self) -> str: + return "xfail" + + def get_context(self, test_case): + return test_case.assertRaisesRegex( + # failing within torch.compile wraps within a BackendCompilerFailed + (self.error_type, torch._dynamo.exc.BackendCompilerFailed), + self.error_msg, + ) + + +# useful for specifying skips +@dataclass +class SkipRule(SampleRule): + @property + def type(self): + return "skip" + + def get_context(self, test_case): + @contextlib.contextmanager + def skipcontext(test_case=test_case): + test_case.skipTest("Skipped!") + yield + + return skipcontext() + + +# Decorator that defines skip / xfail rules for a given test function. If these are +# present, the @ops decorator will apply these for each op and place them onto the +# parametrized test functions for use by e.g. OpInfo.sample_inputs(). +class sample_skips_and_xfails: + def __init__(self, rules): + self.rules = rules + + def __call__(self, fn): + rules = getattr(fn, "sample_skips_and_xfails", None) + if rules is not None: + raise RuntimeError("Multiple sets of sample_skips_and_xfails defined") + + fn.sample_skips_and_xfails = self.rules + return fn + + +def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs): + """Generates input tensors for testing reduction operators""" + yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor( + [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad + ) + + +def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): + """Generates a subset of all valid dim and keepdim kwargs given ndim that + is appropriate for testing reduction operators. + """ + + # Test default dim and keepdim + yield {} + + # Test reducing inner and outer most dimensions + yield {"dim": 0, "keepdim": True} + yield {"dim": -1, "keepdim": False} + + # Test reducing middle dimension + if ndim > 2: + yield {"dim": ndim // 2, "keepdim": True} + + if supports_multiple_dims: + # Test reducing all dimensions + yield {"dim": tuple(range(ndim)), "keepdim": False} + + # Test reducing both first and last dimensions + if ndim > 1: + yield {"dim": (0, -1), "keepdim": True} + + # Test reducing every other dimension starting with the second + if ndim > 3: + yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False} + + +def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for reduction operators.""" + + # TODO(@heitorschueroff) Once all reduction operators are using + # ReductionOpInfo use op_info.supports_multiple_dims directly. + supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True) + + # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo + # use op_info.generate_args_kwargs directly. + generate_args_kwargs = kwargs.get( + "generate_args_kwargs", lambda *args, **kwargs: (yield (), {}) + ) + + for t in _generate_reduction_inputs(device, dtype, requires_grad): + for reduction_kwargs in _generate_reduction_kwargs( + t.ndim, supports_multiple_dims + ): + for args, kwargs in generate_args_kwargs(t, **reduction_kwargs): + kwargs.update(reduction_kwargs) + yield SampleInput( + t.detach().requires_grad_(requires_grad), args=args, kwargs=kwargs + ) + + +# NOTE [Reductions]: +# +# For testing purposes, we relax the definition of a reduction operator +# as defined in the docstring below. We do this to capture operators with +# a similar API so they can be tested automatically. However... +# +# Strictly speaking a reduction operator is an operator that can reduce an +# array to a single scalar value and that can be computed from the partial +# result of reducing subarrays. This usually means that the reduction operation +# should be commutative and associative. This definition is important when it +# comes to implementation as it determines how a reduction can be parallelized. +# +# For example, many summary statistics such as median, mode and quantile cannot +# be computed from partial results because these are sorting and counting based +# algorithms that need information that would be lost in the reduced value. +class ReductionOpInfo(OpInfo): + """Reduction operator information. + + An operator is a reduction operator if it reduces one or more dimensions of + the input tensor to a single value. Reduction operators must implement the + following signature: + + - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor` + + ReductionOpInfo tests that reduction operators implement a consistent API. + Optional features such as reducing over multiple dimensions are captured in + the optional keyword parameters of the ReductionOpInfo constructor. + + If a reduction operator does not yet implement the full required API of + reduction operators, this should be documented by xfailing the failing + tests rather than adding optional parameters to ReductionOpInfo. + + NOTE + The API for reduction operators has not yet been finalized and some + requirements may change. + + See tests in test/test_reductions.py + """ + + def __init__( + self, + name, + *, + # The identity value for the operator if it has one. + identity: Optional[Any] = None, + # The nan policy for the operator if it implements one. + # - propagate: NaN values are propagated to the output + # - omit: NaN values are discarded during the reduction + nan_policy: Optional[str] = None, + # Whether the operator supports reducing multiple dimensions. + supports_multiple_dims: bool = True, + # Whether the operator promotes integral to floating point dtypes. + promotes_int_to_float: bool = False, + # Whether the operator promotes all integral dtypes to int64. + promotes_int_to_int64: bool = False, + # If a specific dtype is given, then the operator always returns that + # dtype irrespective of the input dtype. If None, the operator returns + # the dtype according to the type promotion rules above. + result_dtype: Optional[torch.dtype] = None, + # Casts complex results to real (e.g. linalg.norm or torch.var) + complex_to_real: bool = False, + # ReductionOpInfo tests generate their own input, dim and keepdim + # arguments and call this function to generate tuples of extra args and + # kwargs to use when calling the op. This is required for operators that + # have other required parameters besides the input tensor. + generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( + yield ( + (), + {}, + ) + ), + # Options from the OpInfo base class + **kwargs, + ): + self._original_reduction_args = locals().copy() + assert nan_policy in (None, "propagate", "omit") + + # These are mutually exclusive options + assert not (result_dtype and promotes_int_to_float) + assert not (result_dtype and promotes_int_to_int64) + assert not (result_dtype and complex_to_real) + assert not (promotes_int_to_float and promotes_int_to_int64) + + # Default sample_inputs_func for ReductionOpInfo which augments sample + # inputs from sample_inputs_reduction with the args and kwargs from + # generate_args_kwargs. This is only used if sample_inputs_func is None. + def sample_inputs_func(*args, **kwargs): + kwargs["supports_multiple_dims"] = supports_multiple_dims + kwargs["generate_args_kwargs"] = generate_args_kwargs + yield from sample_inputs_reduction(*args, **kwargs) + + # Override OpInfo defaults and call base class __init__ + kwargs.setdefault("inplace_variant", None) + kwargs.setdefault("sample_inputs_func", sample_inputs_func) + super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs) + + self.identity = identity + self.nan_policy = nan_policy + self.supports_multiple_dims = supports_multiple_dims + self.promotes_int_to_int64 = promotes_int_to_int64 + self.complex_to_real = complex_to_real + self.result_dtype = result_dtype + self.generate_args_kwargs = generate_args_kwargs + + +# The base reference input generation for elementwise binary operations +def _reference_inputs_elementwise_binary( + op, device, dtype, requires_grad, exclude_zero, **kwargs +): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + yield from generate_elementwise_binary_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + if dtype is not torch.bool: + yield from generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + if dtype not in (torch.bool, torch.uint8, torch.int8): + yield from generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield from generate_elementwise_binary_broadcasting_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + yield from generate_elementwise_binary_with_scalar_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + if dtype.is_floating_point or dtype.is_complex: + yield from generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + +# Note that these references inputs use scalars for the SampleInput.input value, +# and many tests require SampleInput.input be a tensor or a list of tensors +def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + gen = partial( + _reference_inputs_elementwise_binary, + op, + device, + dtype, + requires_grad, + exclude_zero, + **kwargs, + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_binary_noncontiguous_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + yield from generate_elementwise_binary_arbitrarily_strided_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + +# A functional that extends an elementwise binary operator's bespoke error inputs +# with generic error inputs for the class of elementwise binary operations +def make_error_inputs_elementwise_binary(error_inputs_func): + def error_inputs_func_wrapper(op, device, **kwargs): + if error_inputs_func is not None: + yield from error_inputs_func(op, device, **kwargs) + + if not op.supports_rhs_python_scalar: + si = SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if not op.supports_one_python_scalar: + si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if ( + not kwargs.get("skip_two_python_scalars", False) + and not op.supports_two_python_scalars + ): + si = SampleInput(2, args=(3,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + return error_inputs_func_wrapper + + +# The following functions and classes are for testing elementwise binary operators. + + +# Returns a generator of pairs of contiguous tensors on the requested device +# and with the requested dtype. +# +# This function is intended to test the non-vectorized and vectorized code +# paths of elementwise binary functions, as well as their handling of odd tensor +# sizes (like zero-dim tensors and tensors with zero elements). +# +# Each iterable will include an a tensor with no elements, +# zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and +# a large 2D tensor. +def generate_elementwise_binary_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + # tensors with no elements + (0,), + (1, 0, 3), + # zero dim (scalar) tensor + (), + # small 1D tensor + (20,), + # medium 1D tensor + (812,), + # large 2D tensor + (1029, 917), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + yield SampleInput( + lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + +def generate_elementwise_binary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + b = make_arg(shape) + yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Returns a generator of pairs of contiguous tensors on the requested device and with +# the requested dtype. +# +# Unlike the previous function, the values in these tensors are specified manually. +def generate_elementwise_binary_small_value_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=None +): + if exclude_zero is None: + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + # defines interesting values + _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254) + _int_vals = (0, -1, 1, -55, 55, -127, 127, -128) + _float_vals = ( + 0.0, + -0.0, + -0.001, + 0.001, + -0.25, + 0.25, + -1.0, + 1.0, + -math.pi / 2, + math.pi / 2, + -math.pi + 0.00001, + math.pi - 0.00001, + -math.pi, + math.pi, + -math.pi - 0.00001, + math.pi + 0.00001, + ) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_vals, _float_vals) + elif dtype.is_complex: + complex_vals = product(_float_vals, _float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64): + prod = product(_int_vals, _int_vals) + elif dtype is torch.uint8: + prod = product(_unsigned_int_vals, _unsigned_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + if r == 0 and exclude_zero: + r_vals.append(1) + else: + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + _large_int_vals = (-1113, 1113, -10701, 10701) + _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7) + _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) + + l_vals = [] + r_vals = [] + + if dtype == torch.float16: + prod = product(_large_float16_vals, _large_float16_vals) + elif dtype.is_floating_point: + prod = product(_large_float_vals, _large_float_vals) + elif dtype.is_complex: + complex_vals = product(_large_float_vals, _large_float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int16, torch.int32, torch.int64): + prod = product(_large_int_vals, _large_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + _float_extremals = (float("inf"), float("-inf"), float("nan")) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_extremals, _float_extremals) + elif dtype.is_complex: + complex_vals = product(_float_extremals, _float_extremals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + # Test case for NaN propagation + nan = ( + float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan")) + ) + lhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + lhs.view(-1)[::3] = nan + rhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + rhs.view(-1)[::3] = nan + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +# Returns a generator of pairs of contiguous and noncontiguous tensors that +# require broadcasting +def generate_elementwise_binary_broadcasting_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + ((1,), ()), + ((2,), ()), + ((1,), (2,)), + ((2, 1), (2,)), + ((1, 2), (2,)), + ((3, 2), (2,)), + ((1, 3, 2), (2,)), + ((1, 3, 2), (3, 2)), + ((3, 1, 2), (3, 2)), + ((2, 3, 2), ()), + ((3, 1, 2), (1, 3, 2)), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, noncontiguous in product(shapes, [True, False]): + shape_lhs, shape_rhs = shape + lhs = make_arg( + shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs + ) + rhs = make_arg( + shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs + ) + + yield SampleInput( + lhs, + args=(rhs,), + broadcasts_input=True, + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and scalars +def generate_elementwise_binary_with_scalar_samples( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5)) + if op.supports_rhs_python_scalar: + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + lhs_scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + if op.supports_two_python_scalars: + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs_scalar, + args=(rhs_scalar,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion +def generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, *, device, dtype, requires_grad=False +): + # add these samples only for logical and comparison ops, arithmetic ops are not happy about extremal scalars + if op.name in ( + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "logical_and", + "logical_or", + "logical_xor", + ): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + shape = ( + 23, + ) # this shape is big enough to trigger vectorization, and has non-vectorized tail + values = (float("nan"), float("inf"), -float("inf")) + scalar_tensors = tuple(torch.tensor(val) for val in values) + if op.supports_rhs_python_scalar: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + for scalar in values + scalar_tensors: + yield SampleInput( + lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, scalar)[0], + ) + + +# Returns a generator of pairs of noncontiguous tensors +def generate_elementwise_binary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + # Generic noncontiguity + lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs) + rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + yield SampleInput( + lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Transposed + lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs) + rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # More noncontiguity + shapes = ((5, 7), (1024,)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + lhs_non_contig.copy_(lhs) + + rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + rhs_non_contig.copy_(rhs) + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Noncontiguous indices + shape = (2, 2, 1, 2) + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs[:, 1, ...] + rhs_non_contig = rhs[:, 1, ...] + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs.expand(3, -1, -1) + rhs_non_contig = rhs.expand(3, -1, -1) + + yield SampleInput( + lhs_non_contig, + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Sample inputs for elementwise binary operators, like add +def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + _M = S if kwargs.get("small_inputs_only", False) else M + _S = XS if kwargs.get("small_inputs_only", False) else S + + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + shapes = ( + ((), ()), + ((_S,), ()), + ((_S, 1), (_S,)), + ((_M, _S), ()), + ((_S, _M, _S), (_M, _S)), + ((_S, _M, _S), (_S, _M, _S)), + ((_M, 1, _S), (_M, _S)), + ((_M, 1, _S), (1, _M, _S)), + ((0, 1, XS), (0, _M, XS)), + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput( + lhs, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + broadcasts_input=broadcasts_input, + ) + + +# Metadata class for binary "universal functions (ufuncs)" that accept two +# tensor and have common properties +class BinaryUfuncInfo(OpInfo): + """Operator information for 'universal binary functions (binary ufuncs).' + These are functions of two tensors with common properties like: + - they are elementwise functions + - the output shape is determined by the input shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/stable/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, + *, + sample_inputs_func=sample_inputs_elementwise_binary, + reference_inputs_func=reference_inputs_elementwise_binary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + error_inputs_func=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + always_returns_bool=False, # Set to true if the op always returns bool tensors + supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs + supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs + supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs + **kwargs, + ): + self._original_binary_ufunc_args = locals().copy() + + # Elementwise binary operations perform the equivalent of test_numpy_refs + # in test_binary_ufuncs, but with additional test granularity. So the + # generic test_ops.py test is skipped because it's redundant. + common_skips = ( + DecorateInfo( + unittest.skip("Skipping redundant test."), + "TestCommon", + "test_numpy_refs", + ), + ) + kwargs["skips"] = kwargs.get("skips", ()) + common_skips + super().__init__( + name, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func), + **kwargs, + ) + + self.sample_kwargs = sample_kwargs + + # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. + if lhs_make_tensor_kwargs is None: + lhs_make_tensor_kwargs = {} + self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs + + if rhs_make_tensor_kwargs is None: + rhs_make_tensor_kwargs = {} + self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs + + self.always_returns_bool = always_returns_bool + self.supports_rhs_python_scalar = supports_rhs_python_scalar + self.supports_one_python_scalar = supports_one_python_scalar + self.supports_two_python_scalars = supports_two_python_scalars + + if self.supports_two_python_scalars: + self.supports_one_python_scalar = True + + if self.supports_one_python_scalar: + assert supports_rhs_python_scalar, ( + "Can't support lhs and rhs Python scalars but not rhs scalars!" + ) + + +# The following functions and classes are for testing elementwise unary operators. +def sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + _L = S if kwargs.get("small_inputs_only", False) else L + + low, high = op_info.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op_info._domain_eps + high = high if high is None or not is_floating else high - op_info._domain_eps + if ( + op_info.supports_sparse_csr + or op_info.supports_sparse_csc + or op_info.supports_sparse_bsr + or op_info.supports_sparse_bsc + ): + # Tensors with dim=2 for sparse compressed testing + yield SampleInput( + make_tensor( + (_L, _L), + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + else: + # Creates a 1D, empty, and scalar tensor + for shape in ((_L,), (1, 0, 3), ()): + yield SampleInput( + make_tensor( + shape, + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + + +# Replace values satisfying condition with a safe value. This is used to block +# out values the could cause singularity like tan(pi/2) +def _replace_values_in_tensor(tensor, condition, safe_value): + mask = condition(tensor) + tensor.masked_fill_(mask, safe_value) + + +# Helper to create a unary elementwise tensor with valid inputs +def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs): + low, high = op.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs) + + if op.reference_numerics_filter is not None and dtype is not torch.bool: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + return a + + +# Restricts the values in the tensor to the domain of the +# given elementwise unary operator +def _filter_unary_elementwise_tensor(a, *, op): + # short-circuits for boolean tensors + if a.dtype is torch.bool: + return a + + low, high = op.domain + is_floating = a.dtype.is_floating_point or a.dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + if a.dtype is torch.uint8 and low is not None: + low = max(low, 0) + + if not a.dtype.is_floating_point and not a.dtype.is_complex: + low = math.ceil(low) if low is not None else None + high = math.floor(high) if high is not None else None + + if op.reference_numerics_filter is not None: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + if low is not None or high is not None: + if a.dtype.is_complex: + a.real.clamp_(low, high) + a.imag.clamp_(low, high) + else: + a.clamp_(min=low, max=high) + + return a + + +def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs): + # Special-cases bool + if dtype is torch.bool: + tensors = ( + torch.empty(0, device=device, dtype=torch.bool), + torch.tensor(True, device=device), + torch.tensor(False, device=device), + torch.tensor((True, False), device=device), + make_tensor((812,), device=device, dtype=dtype), + make_tensor((1029, 917), device=device, dtype=dtype), + ) + for a in tensors: + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + shapes = ( + (1029, 917), + (812,), + # Empty sizes + (0,), + (0, 3, 3), + (1, 0, 5), + (6, 0, 0, 0), + (3, 0, 1, 0), + ) + + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + for shape in shapes: + a = make_arg(shape) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_small_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + yield SampleInput( + sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0] + ) + + +def generate_elementwise_unary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + # Generic noncontiguity + t = make_arg((1026,), noncontiguous=True) + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Transposed + t = make_arg((1024, 1024)).T + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + t = make_arg(shape) + t_non_contig = t.expand(3, -1, -1) + yield SampleInput( + t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0] + ) + + +def generate_elementwise_unary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Reuses the elementwise binary generators for consistency +# TODO: in the future generalize the reference generators to handle n-ary elementwise operations +def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + yield from generate_elementwise_unary_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype is not torch.bool: + yield from generate_elementwise_unary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + if dtype not in (torch.bool, torch.uint8, torch.int8) and ( + op.handles_large_floats + or (not dtype.is_floating_point and not dtype.is_complex) + ): + yield from generate_elementwise_unary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype.is_floating_point or ( + op.handles_complex_extremal_values and dtype.is_complex + ): + yield from generate_elementwise_unary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + gen = partial( + _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_unary_noncontiguous_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + yield from generate_elementwise_unary_arbitrarily_strided_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +# Metadata class for unary "universal functions (ufuncs)" that accept a single +# tensor and have common properties like: +class UnaryUfuncInfo(OpInfo): + """Operator information for 'universal unary functions (unary ufuncs).' + These are functions of a single tensor with common properties like: + - they are elementwise functions + - the input shape is the output shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, # the string name of the function + *, + dtypes=floating_types(), + domain=(None, None), # the [low, high) domain of the function + handles_complex_extremal_values=True, # whether the op correctly handles extremal values (like nan/inf) + handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) + supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle + sample_inputs_func=sample_inputs_elementwise_unary, + reference_inputs_func=reference_inputs_elementwise_unary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + reference_numerics_filter=None, # Filters values in the range of the domain specified above but that should not be tested + **kwargs, + ): + self._original_unary_ufunc_args = locals().copy() + + super().__init__( + name, + dtypes=dtypes, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + **kwargs, + ) + self.domain = domain + self.handles_complex_extremal_values = handles_complex_extremal_values + self.handles_large_floats = handles_large_floats + self.supports_complex_to_float = supports_complex_to_float + self.reference_numerics_filter = reference_numerics_filter + + # test_unary_ufuncs.py generates its own inputs to test the consistency + # of the operator on sliced tensors, non-contig tensors, etc. + # `sample_kwargs` is a utility function to provide kwargs + # along with those inputs if required (eg. clamp). + # It should return two dictionaries, first holding kwarg for + # torch operator and second one for reference NumPy operator. + self.sample_kwargs = sample_kwargs + + # Epsilon to ensure grad and gradgrad checks don't test values + # outside a function's domain. + self._domain_eps = 1e-5 + + +def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs): + is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half + if not is_fp16_or_chalf: + nd_tensor = partial( + make_tensor, + (S, S + 1, S + 2), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad + ) + else: + # cuFFT supports powers of 2 for half and complex half precision + # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args + # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two + low = None + high = None + if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]: + shapes = ((2, 9, 9), (33,)) + elif self.name in [ + "fft.hfft2", + "fft.irfft2", + "_refs.fft.hfft2", + "_refs.fft.irfft2", + ]: + shapes = ((2, 8, 9), (33,)) + elif self.name in [ + "fft.hfftn", + "fft.irfftn", + "_refs.fft.hfftn", + "_refs.fft.irfftn", + ]: + shapes = ((2, 2, 33), (33,)) + # Adjusting the limits because the test would be flaky due to over-saturation of float16 + # See: https://github.com/pytorch/pytorch/pull/81416 + low = -1.0 + high = 1.0 + else: + shapes = ((2, 8, 16), (32,)) + nd_tensor = partial( + make_tensor, + shapes[0], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, + shapes[1], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + + if self.ndimensional == SpectralFuncType.ND: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(8,)) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)]) + elif self.ndimensional == SpectralFuncType.TwoD: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8)) + yield SampleInput(nd_tensor(), dim=0) + yield SampleInput(nd_tensor(), dim=(0, -1)) + yield SampleInput(nd_tensor(), dim=(-3, -2, -1)) + else: + yield SampleInput( + nd_tensor(), + n=10 if not is_fp16_or_chalf else 8, + dim=1, + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3]) + + +SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND")) + + +# Metadata class for Fast Fourier Transforms in torch.fft. +class SpectralFuncInfo(OpInfo): + """Operator information for torch.fft transforms.""" + + def __init__( + self, + name, # the string name of the function + *, + ref=None, # Reference implementation (probably in np.fft namespace) + dtypes=floating_and_complex_types(), + ndimensional: SpectralFuncType, + sample_inputs_func=sample_inputs_spectral_ops, + decorators=None, + **kwargs, + ): + self._original_spectral_func_args = dict(locals()).copy() + self._original_spectral_func_args.update(kwargs) + + decorators = list(decorators) if decorators is not None else [] + decorators += [ + skipCPUIfNoFFT, + DecorateInfo( + toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}), + "TestCommon", + "test_complex_half_reference_testing", + ), + ] + + super().__init__( + name=name, + dtypes=dtypes, + decorators=decorators, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + self.ndimensional = ndimensional + + +class ShapeFuncInfo(OpInfo): + """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" + + def __init__( + self, + name, # the string name of the function + *, + ref, # a reference function + dtypes=floating_types(), + dtypesIfCUDA=None, + dtypesIfROCM=None, + dtypesIfXPU=None, + sample_inputs_func=None, + **kwargs, + ): + super().__init__( + name, + dtypes=dtypes, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + dtypesIfXPU=dtypesIfXPU, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + + +def sample_inputs_foreach( + self, + device, + dtype, + N, + *, + noncontiguous=False, + same_size=False, + low=None, + high=None, + # zero_size means EVERY input is empty + zero_size: bool, + requires_grad: bool, + # mutually exclusive from same_size and zero_size, which are all or nothing + intersperse_empty_tensors: bool = False, +): + if zero_size: + return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)] + if same_size: + return [ + make_tensor( + (N, N), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for _ in range(N) + ] + else: + # interweave some empty tensors + have the last 2 tensors be empty (see #100701) + return [ + torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad) + if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors + else make_tensor( + (N - i, N - i), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for i in range(N) + ] + + +def get_foreach_method_names(name): + # get torch inplace reference function + op_name = "_foreach_" + name + inplace_op_name = op_name + "_" + + op = getattr(torch, op_name, None) + inplace_op = getattr(torch, inplace_op_name, None) + + ref = getattr(torch, name, None) + ref_inplace = getattr(torch.Tensor, name + "_", None) + return op, inplace_op, ref, ref_inplace + + +@dataclass +class ForeachFuncInfo(OpInfo): + """Early version of a specialized OpInfo for foreach functions + + The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM` + are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments. + + ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``) + as the last keyword argument such as `_foreach_add`. + ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument. + Currently only `_foreach_pow` supports this. + ``backward_requires_result=True``, which could sound self-explanatory, means that the function uses + the forward result for its backward computation. + """ + + supports_alpha_param: bool = False + supports_scalar_self_arg: bool = False + backward_requires_result: bool = False + + def __post_init__(self): + ( + foreach_method, + foreach_method_inplace, + torch_ref_method, + torch_ref_inplace, + ) = get_foreach_method_names(self.name) + if not self.supports_out: + # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call + # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero` + # is not defined at the moment. Thus to skip the qualification, set a similar torch + # function. + assert foreach_method is None + assert torch_ref_method is None + foreach_method = foreach_method_inplace + torch_ref_method = torch_ref_inplace + + # We disable all complex128 tests internally for foreach due to reported flakiness + # tracked in #139648 + supported_dtypes = get_all_dtypes(include_qint=False) + if IS_FBCODE: + supported_dtypes = [ + x for x in supported_dtypes if x is not torch.complex128 + ] + self.dtypes = _dispatch_dtypes(supported_dtypes) + + self.op = foreach_method + self.method_variant = foreach_method + self.ref = torch_ref_method + self.inplace_variant = foreach_method_inplace + self.ref_inplace = torch_ref_inplace + self.has_no_in_place = self.inplace_variant is None + + name = self.name + self.name = f"_foreach_{name}" + if name == "norm": + self.ref = torch.linalg.vector_norm + elif name == "minimum": + # because minimum ref does not support inplace or scalar + self.ref = torch.clamp_max + self.ref_inplace = torch.Tensor.clamp_max_ + elif name == "maximum": + # because maximum ref does not support inplace or scalar + self.ref = torch.clamp_min + self.ref_inplace = torch.Tensor.clamp_min_ + + # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly. + super().__post_init__() + + def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs): + if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"): + return [] + return self.sample_inputs_func.sample_zero_size_tensor_inputs( + self, device, dtype, requires_grad, **kwargs + ) + + +def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): + """Gradcheck wrapper for functions that take Hermitian matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the Hermitian property of the input. + """ + return op(input + input.mH, *args, **kwargs) + + +def gradcheck_wrapper_ctc_loss(op, input, *args, **kwargs): + """Gradcheck wrapper for ctc loss to project onto log-simplex space.""" + # See https://github.com/pytorch/pytorch/issues/52241 + return op(input.log_softmax(dim=2), *args, **kwargs) + + +def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs): + """Gradcheck wrapper for functions that take lower or upper triangular matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the triangular property of the input. + `idx` is used to specific which `args[idx]` is to be triangularized. + """ + triangular_arg = args[idx].triu() if upper else args[idx].tril() + return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs) + + +def gradcheck_wrapper_triangular_input_real_positive_diagonal( + op, *args, upper=False, idx=0, **kwargs +): + """Gradcheck wrapper for functions that take lower/upper triangular matrices + with real and positive diagonals, for example, cholesky-like operations. + """ + arg = args[idx] + arg_diag = arg.diagonal(0, -2, -1) + arg_diag_embed = torch.diag_embed(arg_diag) + id_diag_tensor = torch.ones_like(arg_diag) + id_tensor = torch.diag_embed(id_diag_tensor) + # new_arg = arg - diag(arg) + I + new_arg = arg - arg_diag_embed + id_tensor + return gradcheck_wrapper_triangular_input( + op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs + ) + + +def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked operations. + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + mask = kwargs.get("mask") + if mask is not None: + output_mask = torch.masked._output_mask(op, input, *args, **kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked pointwise operations. Assumes that the result + will be masked iff both tensors are masked at a specific index + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + input_mask = kwargs.get("input_mask") + other_mask = kwargs.get("other_mask") + if input_mask is not None and other_mask is not None: + combined_mask = torch.logical_and(input_mask, other_mask) + new_kwargs = dict(mask=combined_mask, **kwargs) + output_mask = torch.masked._input_mask(input, *args, **new_kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def clone_sample(sample, **kwargs): + """ + Given a SampleInput, this function analyzes its input, args and kwargs, + and produces a copy with each non-Tensor entry being copied by reference, + and with each Tensor entry cloned with `t.clone().requires_grad_(t.requires_grad)` + """ + + def clone_tensor(t): + if isinstance(t, torch.Tensor): + return t.detach().clone().requires_grad_(t.requires_grad) + else: + return t + + sample_kwargs = kwargs if kwargs else sample.kwargs + + return SampleInput( + clone_tensor(sample.input), + args=tuple(map(clone_tensor, sample.args)), + kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()}, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py new file mode 100644 index 0000000000000000000000000000000000000000..435a9d113164b3652af4d246655f579d1b72d4dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py @@ -0,0 +1,207 @@ +# mypy: ignore-errors + +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + OpInfo, + ReductionOpInfo, + UnaryUfuncInfo, +) + + +# NOTE [Python References] +# Python References emulate existing PyTorch operations, but can ultimately +# be expressed in terms of "primitive" operations from torch._prims. +# +# These references are experimental. +# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 +# for additional context. +# +# Python Reference OpInfos should be added to the python_ref_db list below. +# Tests can opt-into running on these references by including +# that list in the Sequence they pass to the @ops decorator. +# +# When a Python Reference OpInfo is constructed a pointer to an +# existing OpInfo must be provided using the torch_opinfo_name kwarg. +# The existing OpInfo with that name and no variant will be found +# to inherit from. +# +# Instead of just inheriting the existing OpInfo's metadata, the +# Python Reference OpInfos inherit the existing OpInfo's +# construction arguments. These arguments can be overridden +# by adding kwargs to the constructor. + + +def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None): + """ + Finds the OpInfo with the given name that has no variant name. + """ + # NOTE: searching the global op_db doesn't work when OpInfos are split into + # different modules, as otherwise the op_db will not be fully constructed + # yet. So, instead the local op_db must be passed in explicitly. + if op_db is None: + from torch.testing._internal.common_methods_invocations import op_db + + for opinfo in op_db: + if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name: + return opinfo + + +def _inherit_constructor_args(name, op, inherited, overrides): + # inherits metadata + common_kwargs = { + "name": name, + "op": op, + "aliases": None, # TODO add a check for alias coverage + "method_variant": None, + "inplace_variant": None, # TODO: add a check for inplace coverage + "supports_scripting": False, + } + + # Acquires inherited kwargs + kwargs = inherited.copy() + + # Fixes metadata + if "kwargs" in kwargs: + kwargs.update(kwargs["kwargs"]) + del kwargs["kwargs"] + if "self" in kwargs: + del kwargs["self"] + if "__class__" in kwargs: + del kwargs["__class__"] + if "skips" in kwargs: + del kwargs["skips"] + if "decorators" in kwargs: + del kwargs["decorators"] + + # Overrides metadata + kwargs.update(common_kwargs) + kwargs.update(overrides) + + # At the moment no prims support autograd, so we must not run autograd + # tests e.g. when testing dtype support. Once we start writing autograd + # formulas for prims this can be removed. + kwargs["supports_autograd"] = False + kwargs["supports_gradgrad"] = False + kwargs["supports_fwgrad_bwgrad"] = False + kwargs["supports_inplace_autograd"] = False + kwargs["supports_forward_ad"] = False + + return kwargs + + +class PythonRefInfo(OpInfo): + """ + An OpInfo for a Python reference of an OpInfo base class operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, OpInfo) + + inherited = self.torch_opinfo._original_opinfo_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + super().__init__(**ukwargs) + + +class ReductionPythonRefInfo(ReductionOpInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, ReductionOpInfo) + + inherited = self.torch_opinfo._original_reduction_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + # See https://github.com/pytorch/pytorch/issues/77216 + self.validate_view_consistency = False + + super().__init__(**ukwargs) + + +class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, UnaryUfuncInfo) + + inherited = self.torch_opinfo._original_unary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) + + +class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise binary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, BinaryUfuncInfo) + + inherited = self.torch_opinfo._original_binary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e2127e956b46c711961bf90d822a461b99aedd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py @@ -0,0 +1,276 @@ +# mypy: ignore-errors + +import collections +import warnings +from collections.abc import Sequence +from functools import partial, wraps + +import numpy as np +import numpy.typing as npt + +import torch +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + all_types, + all_types_and, + all_types_and_complex, + all_types_and_complex_and, + all_types_and_half, + complex_types, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + floating_types_and, + floating_types_and_half, + integral_types, + integral_types_and, +) +from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict + + +COMPLETE_DTYPES_DISPATCH = ( + all_types, + all_types_and_complex, + all_types_and_half, + floating_types, + floating_and_complex_types, + floating_types_and_half, + integral_types, + complex_types, +) + +EXTENSIBLE_DTYPE_DISPATCH = ( + all_types_and_complex_and, + floating_types_and, + floating_and_complex_types_and, + integral_types_and, + all_types_and, +) + +# Better way to acquire devices? +DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else []) + + +class _dynamic_dispatch_dtypes(_dispatch_dtypes): + # Class to tag the dynamically generated types. + pass + + +def get_supported_dtypes(op, sample_inputs_fn, device_type): + # Returns the supported dtypes for the given operator and device_type pair. + assert device_type in ["cpu", "cuda"] + if not TEST_CUDA and device_type == "cuda": + warnings.warn( + "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!", + stacklevel=2, + ) + return _dynamic_dispatch_dtypes(()) + + supported_dtypes = set() + for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): + try: + samples = sample_inputs_fn(op, device_type, dtype, False) + except RuntimeError: + # If `sample_inputs_fn` doesn't support sampling for a given + # `dtype`, we assume that the `dtype` is not supported. + # We raise a warning, so that user knows that this was the case + # and can investigate if there was an issue with the `sample_inputs_fn`. + warnings.warn( + f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}", + stacklevel=2, + ) + continue + + # We assume the dtype is supported + # only if all samples pass for the given dtype. + supported = True + for sample in samples: + try: + op(sample.input, *sample.args, **sample.kwargs) + except RuntimeError: + # dtype is not supported + supported = False + break + + if supported: + supported_dtypes.add(dtype) + + return _dynamic_dispatch_dtypes(supported_dtypes) + + +def dtypes_dispatch_hint(dtypes): + # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH) + # and its string representation for the passed `dtypes`. + return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str") + + # CUDA is not available, dtypes will be empty. + if len(dtypes) == 0: + return return_type((), "()") + + set_dtypes = set(dtypes) + for dispatch in COMPLETE_DTYPES_DISPATCH: + # Short circuit if we get an exact match. + if set(dispatch()) == set_dtypes: + return return_type(dispatch, dispatch.__name__ + "()") + + chosen_dispatch = None + chosen_dispatch_score = 0.0 + for dispatch in EXTENSIBLE_DTYPE_DISPATCH: + dispatch_dtypes = set(dispatch()) + if not dispatch_dtypes.issubset(set_dtypes): + continue + + score = len(dispatch_dtypes) + if score > chosen_dispatch_score: + chosen_dispatch_score = score + chosen_dispatch = dispatch + + # If user passed dtypes which are lower than the lowest + # dispatch type available (not likely but possible in code path). + if chosen_dispatch is None: + return return_type((), str(dtypes)) + + return return_type( + partial(dispatch, *tuple(set(dtypes) - set(dispatch()))), + dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))), + ) + + +def is_dynamic_dtype_set(op): + # Detect if the OpInfo entry acquired dtypes dynamically + # using `get_supported_dtypes`. + return op.dynamic_dtypes + + +def str_format_dynamic_dtype(op): + fmt_str = f""" + OpInfo({op.name}, + dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str}, + dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str}, + ) + """ + + return fmt_str + + +def np_unary_ufunc_integer_promotion_wrapper(fn): + # Wrapper that passes PyTorch's default scalar + # type as an argument to the wrapped NumPy + # unary ufunc when given an integer input. + # This mimics PyTorch's integer->floating point + # type promotion. + # + # This is necessary when NumPy promotes + # integer types to double, since PyTorch promotes + # integer types to the default scalar type. + + # Helper to determine if promotion is needed + def is_integral(dtype): + return dtype in [ + np.bool_, + bool, + np.uint8, + np.int8, + np.int16, + np.int32, + np.int64, + ] + + @wraps(fn) + def wrapped_fn(x): + # As the default dtype can change, acquire it when function is called. + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + if is_integral(x.dtype): + return fn(x.astype(np_dtype)) + return fn(x) + + return wrapped_fn + + +def reference_reduction_numpy(f, supports_keepdims=True): + """Wraps a NumPy reduction operator. + + The wrapper function will forward dim, keepdim, mask, and identity + kwargs to the wrapped function as the NumPy equivalent axis, + keepdims, where, and initiak kwargs, respectively. + + Args: + f: NumPy reduction operator to wrap + supports_keepdims (bool, optional): Whether the NumPy operator accepts + keepdims parameter. If it does not, the wrapper will manually unsqueeze + the reduced dimensions if it was called with keepdim=True. Defaults to True. + + Returns: + Wrapped function + + """ + + @wraps(f) + def wrapper(x: npt.NDArray, *args, **kwargs): + # Copy keys into a set + keys = set(kwargs.keys()) + + dim = kwargs.pop("dim", None) + keepdim = kwargs.pop("keepdim", False) + + if "dim" in keys: + dim = tuple(dim) if isinstance(dim, Sequence) else dim + + # NumPy reductions don't accept dim=0 for scalar inputs + # so we convert it to None if and only if dim is equivalent + if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}: + kwargs["axis"] = None + else: + kwargs["axis"] = dim + + if "keepdim" in keys and supports_keepdims: + kwargs["keepdims"] = keepdim + + if "mask" in keys: + mask = kwargs.pop("mask") + if mask is not None: + assert mask.layout == torch.strided + kwargs["where"] = mask.cpu().numpy() + + if "identity" in keys: + identity = kwargs.pop("identity") + if identity is not None: + if identity.dtype is torch.bfloat16: + identity = identity.cpu().to(torch.float32) + else: + identity = identity.cpu() + kwargs["initial"] = identity.numpy() + + result = f(x, *args, **kwargs) + + # Unsqueeze reduced dimensions if NumPy does not support keepdims + if keepdim and not supports_keepdims and x.ndim > 0: + dim = list(range(x.ndim)) if dim is None else dim + result = np.expand_dims(result, dim) + + return result + + return wrapper + + +def prod_numpy(a, *args, **kwargs): + """ + The function will call np.prod with type as np.int64 if the input type + is int or uint64 if is uint. This is necessary because windows np.prod uses by default + int32 while on linux it uses int64. + This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320 + + Returns: + np.prod of input + """ + if "dtype" not in kwargs: + if np.issubdtype(a.dtype, np.signedinteger): + a = a.astype(np.int64) + elif np.issubdtype(a.dtype, np.unsignedinteger): + a = a.astype(np.uint64) + + fn = reference_reduction_numpy(np.prod) + return fn(a, *args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9125ba0ebe7e0623a12ad1a1cd7eeb7d2749a3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py @@ -0,0 +1,7 @@ +# mypy: ignore-errors + +from .make_fx import make_fx_check +from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper +from .fake_tensor import fake_check +from .autograd_registration import autograd_registration_check +from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4d05a95a33e262e19efbb4cbb0d3a01d3dbf3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py @@ -0,0 +1,175 @@ +# mypy: ignore-errors + +import torch +import torch.utils._pytree as pytree +from torch.testing._utils import wrapper_set_seed +from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop +from .make_fx import randomize +import re + + +class assert_raises_regex: + def __init__(self, exception_cls, regex): + self.exception_cls = exception_cls + self.regex = regex + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, traceback): + if exc_type == self.exception_cls: + msg = str(exc_val) + if not re.search(self.regex, msg): + raise AssertionError( + f"Expected exception to match regex. regex: {self.regex}, exception: {msg}") + return True # Squashes the exception + if exc_type is not None: + raise AssertionError( + f"Expected {self.exception_cls} to be raised, instead got exception {exc_type}") + raise AssertionError("Expected exception to be raised but none was") + + +def aot_autograd_check( + func, + args, + kwargs, + dynamic, + assert_raises_regex_fn=assert_raises_regex, + assert_equals_fn=torch.testing.assert_close, + check_gradients=True, + try_check_data_specialization=False, + skip_correctness_check=False, + disable_functionalization=False): + """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd. + + Compares outputs and (if check_gradients=True) gradients produced by + AOTAutograd against eager-mode PyTorch. + + We assume that func(*args, **kwargs) succeeds in eager-mode PyTorch. + + """ + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + + # We construct a new function that only accepts Tensors as inputs + def func_no_tensors(args): + reconstructed_flat_args = [] + args = iter(args) + for v in flat_args: + if isinstance(v, torch.Tensor): + reconstructed_flat_args.append(next(args)) + else: + reconstructed_flat_args.append(v) + + c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec) + return func(*c_args, **c_kwargs) + + # cannot use the min cut partitioner without functionalization + if disable_functionalization: + compiled_f = compiled_function( + func_no_tensors, + nop, + nop, + dynamic=dynamic, + partition_fn=default_partition, + keep_inference_input_mutations=True, + disable_functionalization=True + ) + else: + compiled_f = compiled_function( + func_no_tensors, + nop, + nop, + dynamic=dynamic, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + disable_functionalization=False + ) + + out = wrapper_set_seed(func_no_tensors, args) + if check_gradients == "auto": + any_tensor_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, args) + any_output_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, out) + check_gradients = any_tensor_requires_grad and any_output_requires_grad + if not check_gradients: + compiled_out = wrapper_set_seed(compiled_f, args) + if not skip_correctness_check: + assert_equals_fn(compiled_out, out, msg=outputs_msg) + return + _test_aot_autograd_forwards_backwards_helper( + func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check) + +outputs_msg = ( + "Outputs of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher tracing. This means the operator will have incorrect output " + "underneath torch.compile. This could be because the operator's " + "implementation not traceable." +) + + +def _test_aot_autograd_forwards_backwards_helper( + f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check=False): + # Verify grads are equal between compiled and non-compiled versions of f. + + def call_forwards_backwards(f, args): + flat_args = pytree.arg_tree_leaves(*args) + diff_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor) and + arg.requires_grad] + out = wrapper_set_seed(f, args) + flat_out = pytree.tree_leaves(out) + + sm = 0 + for i in flat_out: + if isinstance(i, torch.Tensor): + # We need to call .abs() because it is possible that the output of the + # operator is a complex Tensor and autograd will yell at autograd.grad + # on a complex Tensor unless we manually provide the grad_output flag. + sm += i.sum().abs() + assert isinstance(sm, torch.Tensor) + return out, torch.autograd.grad(sm, diff_args, allow_unused=True) + + def check(args, ignore_failure=False): + try: + orig_out, orig_grad = call_forwards_backwards(f, args) + except Exception: + if ignore_failure: + return + raise + + # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 + tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all(x is None for x in orig_grad) and any_non_leaves: + with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'): + call_forwards_backwards(compiled_f, args) + return + + msg = ( + "Gradients of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher. This means the operator will have incorrect gradients " + "underneath torch.compile. This could be because the operator's " + "backward is incorrectly registered or not traceable." + ) + + compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args) + if not skip_correctness_check: + try: + assert_equals_fn(compiled_out, orig_out) + except Exception as e: + raise type(e)(outputs_msg) from e + try: + assert_equals_fn(compiled_grad, orig_grad) + except Exception as e: + raise type(e)(msg) from e + + check(args, ignore_failure=False) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if this test fails. + if try_check_data_specialization: + args = randomize(args) + check(args, ignore_failure=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5ae34059eaa3d7ae1197699638f52f86538b02 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py @@ -0,0 +1,134 @@ +# mypy: ignore-errors + +import contextlib + +import torch +import torch.utils._pytree as pytree + + +@contextlib.contextmanager +def set_autograd_fallback_mode(mode): + prev = torch._C._get_autograd_fallback_mode() + try: + torch._C._set_autograd_fallback_mode(mode) + yield + finally: + torch._C._set_autograd_fallback_mode(prev) + + +def autograd_registration_check(op, args, kwargs): + """Check if autograd was registered correctly (for the operator). + + Operators should have "autograd support" registered directly to an + autograd dispatch key. + An incorrect registration may lead to unexpected silent incorrectness. + Note that this check won't catch all problems but will catch + the most common ones. + + Example usage: + >>> x = torch.randn(3, requires_grad=True) + >>> autograd_registration_check(torch.ops.aten.sin.default, (x,), {}) + + Here are some best practices if you do find your autograd is + registered incorrectly: + - If the operator is composite (i.e. consists of other PyTorch ops) + and you wish the operator to decompose and get autograd support + that way, then please register the implementation to + DispatchKey::CompositeImplicitAutograd + - If you're adding an autograd formula for the operator, the correct + thing to do is to register an autograd.Function to + DispatchKey::Autograd (preferred) or one of the + DispatchKey::Autograd keys. It is NOT OK to register + an autograd.Function to a backend (e.g. CPU/CUDA) key. + - If your operator is non-differentiable, then you should register + an implementation to the Autograd key that uses + AutoDispatchBelowAutograd and re-invokes the operator. + + """ + assert isinstance(op, torch._ops.OpOverload) + # Implementation details + # ----------------------------------------------- + # If an operator doesn't have an autograd kernel at an autograd key, + # and the operator does not return inputs as-is, then all of + # the outputs should have requires_grad=False before we apply + # special behaviors of our default autograd fallback. + # (The default autograd fallback may set requires_grad=True on output + # tensors in certain modes so that when they are backpropped through, + # they raise an error). + # + # Our strategy for detecting if an operator doesn't have an autograd + # kernel at the autograd key is: + # - set the autograd fallback mode to "nothing" (so it does not change + # the required-gradness of outputs) + # - run the operator + # - Check if any outputs of the operator (that are not inputs) require + # grad. This would only happen if the user calls regular PyTorch + # operations in their backend key (this op should instead be + # CompositeImplicitAutograd or not an op) or if the user invokes + # an autograd.Function in the backend key. + # + # Note that it's already likely a bug if the operator directly returns + # an input as output (because custom ops don't have a good way of + # constructing true in-place or out variants), but we defer that + # responsibility to a different test (schema_check). + + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + if not any(t.requires_grad for t in all_tensors): + raise RuntimeError( + "autograd_registration_check: no inputs have requires_grad=True so " + "we are unable to actually perform this test. Please pass inputs " + "that do require grad." + ) + + # Determine which AutogradBACKEND key to check + all_device_types = {arg.device.type for arg in all_tensors} + if not all_device_types.issubset(["cpu", "cuda", "xpu"]): + # Don't want to support other keys yet + raise NotImplementedError( + f"autograd_registration_check: NYI devices other than CPU/CUDA/XPU, got {all_device_types}" + ) + if "cuda" in all_device_types: + key = "AutogradCUDA" + elif "cpu" in all_device_types: + key = "AutogradCPU" + elif "xpu" in all_device_types: + key = "AutogradXPU" + + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key): + return + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Autograd"): + return + if torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), "CompositeImplicitAutograd" + ): + return + + # At this point, we know the operator doesn't have a kernel registered to an + # autograd key. Let's proceed with our test. + with set_autograd_fallback_mode("nothing"): + all_outs = op(*args, **kwargs) + + inp_ids = {id(arg) for arg in flat_args} + + def not_an_input_and_requires_grad(tensor): + if not tensor.requires_grad: + return False + if id(tensor) in inp_ids: + return False + return True + + if not pytree.tree_any_only(torch.Tensor, not_an_input_and_requires_grad, all_outs): + return + + raise AssertionError( + f"{op.name()}: at least one output of this operator has requires_grad=True " + f"but the operator does not have an autograd kernel defined at an autograd " + f"key (e.g. DispatchKey::Autograd). This could mean that you have " + f"incorrectly registered an autograd kernel to a non-Autograd DispatchKey, " + f"which may lead to silently incorrect results. If your operator consists " + f"of regular PyTorch operations, consider not using an operator at all " + f"or registering your operator as CompositeImplicitAutograd. If you have " + f"an autograd.Function registered to a backend (CPU/CUDA/XPU) key, the correct " + f"location for it is the Autograd key." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5e60f50189b5dc3ab43fdd97120d5fa23559a84e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py @@ -0,0 +1,12 @@ +# mypy: ignore-errors + +import torch._subclasses + + +def is_builtin(op): + return op.namespace in ('aten', 'prims', 'prim') + + +def fake_check(op, args, kwargs): + with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin): + op(*args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..398425853f09adccce056b4115042f9379a1a9b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py @@ -0,0 +1,852 @@ +# mypy: ignore-errors + +import datetime +import difflib +import functools +import inspect +import json +import os +import re +import tempfile +import threading +import unittest +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union + +import torch +import torch._dynamo +import torch.utils._pytree as pytree +from torch._dynamo.utils import clone_input +from torch._library.custom_ops import CustomOpDef +from torch._subclasses.schema_check_mode import SchemaCheckMode +from torch._utils_internal import get_file_path_2 +from torch.overrides import TorchFunctionMode +from torch.testing._internal.optests import ( + aot_autograd_check, + autograd_registration_check, + fake_check, +) + + +def dontGenerateOpCheckTests(reason: str): + def inner(fun): + fun._torch_dont_generate_opcheck_tests = True + return fun + + return inner + + +def is_abstract(tensor: torch.Tensor) -> bool: + if tensor.is_meta: + return True + if torch._subclasses.fake_tensor.is_fake(tensor): + return True + return False + + +def safe_schema_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + with SchemaCheckMode(): + result = op(*args, **kwargs) + return result + + +def safe_autograd_registration_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + # Don't perform autograd_registration_check if none of the inputs require grad. + if not pytree.tree_any_only( + torch.Tensor, lambda x: x.requires_grad, (args, kwargs) + ): + return + return autograd_registration_check(op, args, kwargs) + + +def safe_fake_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + return fake_check(op, args, kwargs) + + +def safe_aot_autograd_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic: bool, + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy + # inputs. + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + + def func(*args, **kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs)) + return op(*args, **kwargs) + + # aot_autograd_check runs func(*args, **kwargs) multiple times + # and assumes `func` does not modify its inputs. + if rtol and atol: + assert_equals_fn = functools.partial( + torch.testing.assert_close, rtol=rtol, atol=atol + ) + else: + assert_equals_fn = torch.testing.assert_close + return aot_autograd_check( + func, + args, + kwargs, + dynamic, + check_gradients="auto", + assert_equals_fn=assert_equals_fn, + ) + + +def deepcopy_tensors(inputs: Any) -> Any: + return pytree.tree_map_only(torch.Tensor, clone_input, inputs) + + +# Test util requirements +# - The test util must have signature (op: OpOverload, args, kwargs) +# - The test util must NOT mutate args, kwargs. +# - The test utils in this list must not be prefixes of each other. For example, +# having both "test_schema" and "test_schema_is_functional" is NOT OK. +# - The order of items in this dict matters (for opcheck), we'll run them +# in order. +ALL_TEST_UTILS = { + "test_schema": safe_schema_check, + "test_autograd_registration": safe_autograd_registration_check, + "test_faketensor": safe_fake_check, + "test_aot_dispatch_static": functools.partial( + safe_aot_autograd_check, + dynamic=False, + ), + "test_aot_dispatch_dynamic": functools.partial( + safe_aot_autograd_check, + dynamic=True, + ), +} + +GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit" + +DEFAULT_TEST_UTILS = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", +] + +DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [ + "test_aot_dispatch_static", +] + + +def generate_opcheck_tests( + testcase: Any, + namespaces: list[str], + failures_dict_path: Optional[str] = None, + additional_decorators: Optional[dict[str, Callable]] = None, + test_utils: list[str] = DEFAULT_TEST_UTILS, +) -> None: + """Given an existing TestCase, use the existing tests to generate + additional validation tests for custom operators. + + For {all existing tests in the TestCase} x {all test utils}, + we will generate one new test. The new test runs a TorchFunctionMode + that intercepts ``op(*args, **kwargs)`` calls and invokes + ``test_util(op, *args, **kwargs)``, where ``op`` is an operator. + + The test_util that we support are in ALL_TEST_UTILS. They are: + - test_schema: This runs SchemaCheckMode. + - test_autograd_registration: This runs autograd_registration_check. + - test_faketensor: This runs CrossRefFakeMode. + - test_aot_dispatch_static: This runs aot_autograd_check, which: + checks that the outputs (and gradients, if they are computable) + are the same under eager-mode PyTorch and using AOTAutograd. + - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but + runs AOTAutograd using dynamic shapes instead of static shapes. + + The generated test will have name ``{test_util}__{original_name}``. + For example, if there is a method named ``test_cumsum``, then + we will generate a ``test_schema__test_cumsum``, + ``test_faketensor__test_cumsum``, etc. + + For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit + + Args: + testcase: The testcase we will modify and generate additional tests for. + namespaces: We will only intercept calls to custom operators with these + namespaces. + failures_dict_path: See ``validate_failures_dict_structure`` for more details + test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"] + """ + if additional_decorators is None: + additional_decorators = {} + test_methods = [ + m + for m in dir(testcase) + if m.startswith("test_") and callable(getattr(testcase, m)) + ] + if failures_dict_path is None: + # The default failures_dict_path is failures_dict.json in + # the same directory as the test file. + prev_frame = inspect.currentframe().f_back + filename = inspect.getframeinfo(prev_frame)[0] + failures_dict_path = get_file_path_2( + os.path.dirname(filename), "failures_dict.json" + ) + failures_dict = FailuresDict.load( + failures_dict_path, create_file=should_update_failures_dict() + ) + validate_failures_dict_structure(failures_dict, test_utils, testcase) + validate_failures_dict_formatting(failures_dict_path) + + def construct_method(attr, prefix, tester): + method = getattr(testcase, attr) + if getattr(method, "_torch_dont_generate_opcheck_tests", False): + return + new_method_name = prefix + "__" + attr + + @functools.wraps(method) + def new_method(*args, **kwargs): + with OpCheckMode( + namespaces, + prefix, + tester, + failures_dict, + f"{testcase.__name__}.{new_method_name}", + failures_dict_path, + ): + result = method(*args, **kwargs) + return result + + if pytestmark := new_method.__dict__.get("pytestmark"): + import pytest + + # check if we need to simplify the parametrize marks + # NB: you need to add this mark to your pytest.ini + opcheck_only_one = False + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one": + opcheck_only_one = True + + if opcheck_only_one: + new_pytestmark = [] + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "parametrize": + argnames, argvalues = mark.args + assert not mark.kwargs, "NYI" + # Special case for device, we want to run on all + # devices + if argnames != "device": + new_pytestmark.append( + pytest.mark.parametrize( + argnames, (next(iter(argvalues)),) + ) + ) + continue + new_pytestmark.append(mark) + new_method.__dict__["pytestmark"] = new_pytestmark + + if new_method_name in additional_decorators: + for dec in additional_decorators[new_method_name]: + new_method = dec(new_method) + + if hasattr(testcase, new_method_name): + raise RuntimeError( + f"Tried to autogenerate {new_method_name} but {testcase} already " + f"has method named {new_method_name}. Please rename the original " + f"method on the TestCase." + ) + setattr(testcase, new_method_name, new_method) + + test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils} + for attr in test_methods: + for prefix, tester in test_utils.items(): + construct_method(attr, prefix, tester) + + generate_tag_tests(testcase, failures_dict, additional_decorators) + + +def generate_tag_tests(testcase, failures_dict, additional_decorators): + def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests): + def inner(self): + try: + op = torch._library.utils.lookup_op(qualname) + except AttributeError as e: + # Operator not importable in this test file + raise unittest.SkipTest(f"Can't import operator {qualname}") from e + op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags + if not op_marked_as_compliant: + return + if not definitely_not_pt2_compliant: + return + raise AssertionError( + f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag " + f"but it failed some of the generated opcheck tests " + f"({xfailed_tests}). This may lead to silent correctness issues, " + f"please fix this." + ) + + return inner + + for qualname, test_dict in failures_dict.data.items(): + xfailed_tests = [ + test + for test, status_dict in test_dict.items() + # We're about to delete the following test after Ed's PR + # to specialize on C++ .size() calls + if "test_aot_dispatch_static" not in test + and status_dict["status"] == "xfail" + ] + definitely_not_pt2_compliant = len(xfailed_tests) > 0 + generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests) + + # Could result in collisions, but unlikely. We'll raise if we see one below. + mangled_qualname = qualname.replace("::", "_").replace(".", "_") + test_name = "test_pt2_compliant_tag_" + mangled_qualname + + # You can skip this test via the additional_decorators argument + # in generate_opcheck_tests + if test_name in additional_decorators: + for decorator in additional_decorators[test_name]: + generated = decorator(generated) + + if hasattr(testcase, test_name): + raise RuntimeError( + f"Tried to generate a test named {test_name}, but it exists " + f"already. This could be because of a name collision (where " + f"we generated two tests with the same name), or where we " + f"generated a test with the same name as an existing test." + ) + setattr(testcase, test_name, generated) + + +TEST_OPTIONS = ("xfail", "skip", "xsuccess") + + +def validate_failures_dict_formatting(failures_dict_path: str) -> None: + with open(failures_dict_path) as fp: + actual = fp.read() + failures_dict = FailuresDict.load(failures_dict_path) + expected = failures_dict._save(to_str=True) + if actual == expected: + return + if should_update_failures_dict(): + failures_dict = FailuresDict.load(failures_dict_path) + failures_dict.save() + return + expected = expected.splitlines(1) + actual = actual.splitlines(1) + diff = difflib.unified_diff(actual, expected) + diff = "".join(diff) + raise RuntimeError( + f"\n{diff}\n\nExpected the failures dict to be formatted " + f"a certain way. Please see the above diff; you can correct " + f"this either manually or by re-running the test with " + f"PYTORCH_OPCHECK_ACCEPT=1" + ) + + +def validate_failures_dict_structure( + failure_dict: "FailuresDict", test_utils: list[str], testcase: Any +) -> None: + """Validates the failures dict. + + The failure dict looks something like the following. + It maps operator name (qualname) to a list of autogenerated tests. + Each autogenerated test may have a check for the operator (if the operator is + called by the test); the dictionary specifies if we should skip the check, + or if we expect some check to fail. + + { + "fbgemm::split_lengths": { + "test_schema__test_split_lengths": { + "comment": "you can put whatever you want into the comment section", + "status": "xfail", + } + "test_schema__test_split_lengths_empty": { + "comment": "", + "status": "skip", + }, + }, + "fbgemm::gather_lengths": { + "test_schema__test_gather_lengths": { + "comment": "", + "status": "skip", + }, + }, + } + + """ + failure_dict = failure_dict.data + for test_to_option in failure_dict.values(): + for test_name, test_dict in test_to_option.items(): + if set(test_dict.keys()) != set({"comment", "status"}): + raise RuntimeError( + "in failures_dict, expected sub-dict to have keys 'comment' and 'status'" + ) + test_option = test_dict["status"] + if test_option not in TEST_OPTIONS: + raise RuntimeError( + f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}" + ) + test_class, actual_test_name = test_name.split(".") + if not any(actual_test_name.startswith(test) for test in test_utils): + raise RuntimeError( + f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}" + ) + for test in test_utils: + if not actual_test_name.startswith(test): + continue + base_test_name = actual_test_name[len(test) + 2 :] + # remove potential pytest parametrization suffix + base_test_name = re.sub(r"\[.*\]", "", base_test_name) + if testcase.__name__ != test_class: + continue + if hasattr(testcase, base_test_name): + continue + raise RuntimeError( + f"In failures dict, got test name '{test_name}'. We parsed this as " + f"running test '{test}' on '{base_test_name}', but " + f"{base_test_name} does not exist on the TestCase '{testcase.__name__}]. " + f"Maybe you need to change the test name?" + ) + + +def should_update_failures_dict() -> bool: + key = "PYTORCH_OPCHECK_ACCEPT" + return key in os.environ and os.environ[key] == "1" + + +_is_inside_opcheck_mode = threading.local() +_is_inside_opcheck_mode.value = False + + +def is_inside_opcheck_mode(): + return _is_inside_opcheck_mode.value + + +class OpCheckMode(TorchFunctionMode): + """ + For a given test, OpCheckMode intercepts calls to operators and runs + test_util(op, args, kwargs) for each intercepted (op, args, kwargs). + """ + + def __init__( + self, + namespaces: list[str], + test_util_name: str, + test_util: Callable, + failures_dict: "FailuresDict", + test_name: str, + failures_dict_path: str, + ): + # We will intercept calls to ops with these namespaces + self.namespaces = namespaces + # The test utility function. Its signature should be (op, args, kwargs) -> None. + # Examples of test utilities are: schema_check, make_fx_check + self.test_util = test_util + self.test_util_name = test_util_name + # The name of the test that is running this OpCheckMode. + self.test_name = test_name + # Maps qualname -> test_name -> skip/xfail + # Tells us if we should skip a test or assert that there is a failure. + self.failures_dict = failures_dict + # Location of the failures dict. Makes it so that the error message is better. + self.failures_dict_path = failures_dict_path + + # OpCheckMode suppresses errors, collects them here, and then raises them on exit. + # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)] + self.seen_ops_to_errors = {} + + def maybe_raise_errors_on_exit(self) -> None: + # Check expected failures first + for qualname in self.seen_ops_to_errors: + option = self.failures_dict.get_status(qualname, self.test_name) + if len(self.seen_ops_to_errors[qualname]) == 0: + if should_update_failures_dict(): + self.failures_dict.set_status( + qualname, self.test_name, "xsuccess", comment="" + ) + else: + if option == "xfail": + raise OpCheckError( + f"generate_opcheck_tests: Unexpected success for operator " + f"{qualname} on test {self.test_name}. This may mean that " + f"you have fixed this test failure. Please rerun the test with " + f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner " + f"or manually remove the " + f"expected failure in the failure dict at " + f"{self.failures_dict_path}" + f"For more details, see " + f"{GDOC}" + ) + continue + failed_ops = [] + for qualname in self.seen_ops_to_errors: + option = self.failures_dict.get_status(qualname, self.test_name) + if option != "xsuccess": + continue + if len(self.seen_ops_to_errors[qualname]) == 0: + continue + failed_ops.append(qualname) + if not failed_ops: + return + + if should_update_failures_dict(): + for op in failed_ops: + self.failures_dict.set_status(op, self.test_name, "xfail") + return + + # Raise from the first error but also report about all of them to make + # recording xfails easier. + ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0] + repro_command = generate_repro( + self.test_util_name, op, args, kwargs, save_data=should_print_better_repro() + ) + raise OpCheckError( + f"Test generated by `generate_opcheck_tests`, {self.test_name}, " + f"failed on operators {failed_ops}. This usually means that the " + f"operators are not implemented correctly and may lead to silently " + f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, " + f"or please see " + f"{GDOC} " + f"for more recommendations. " + f"To reproduce this problem locally, try to run the following:\n{repro_command}" + ) from ex + + def __enter__(self, *args, **kwargs): + self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value + self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE", "") + _is_inside_opcheck_mode.value = True + os.environ["TORCHDYNAMO_DISABLE"] = "1" + return super().__enter__(*args, **kwargs) + + def __exit__(self, *args, **kwargs): + _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode + os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable + try: + self.maybe_raise_errors_on_exit() + if should_update_failures_dict(): + self.failures_dict.save() + finally: + result = super().__exit__(*args, **kwargs) + return result + + def run_test_util(self, op, args, kwargs): + try: + self.test_util(op, args, kwargs, copy_inputs=False) + except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: + # We might get here if the input is already a FakeTensor + # or if we're in a torch.compile block. Just ignore these + # since we can't handle them and reporting them as failures + # is too noisy. + pass + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # Only intercept calls to operators + if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): + return func(*args, **kwargs) + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or torch._dynamo.is_compiling() + ): + return func(*args, **kwargs) + # Pre-existing code may not use the .default overload. If we see an + # OpOverloadPacket and we cannot resolve the overload, then we just throw + # and ask the user to clarify. Otherwise, we attempt to resolve the overload. + if isinstance(func, torch._ops.OpOverloadPacket): + func = resolve_unique_overload_or_throw(func) + qualname = func.name() + ns = qualname.split("::")[0] + if ns not in self.namespaces: + return func(*args, **kwargs) + + args_c, kwargs_c = deepcopy_tensors((args, kwargs)) + result = func(*args, **kwargs) + + option = self.failures_dict.get_status(qualname, self.test_name) + if option == "xsuccess" or option == "xfail": + # Suppress all errors during execution. Raise them during __exit__. + try: + if qualname not in self.seen_ops_to_errors: + self.seen_ops_to_errors[qualname] = [] + self.run_test_util(func, args_c, kwargs_c) + except Exception as ex: + if should_print_better_repro(): + self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs)) + else: + self.seen_ops_to_errors[qualname].append((ex, func, None, None)) + elif option == "skip": + pass + return result + + +def should_print_better_repro() -> None: + """If set, the tests generated by `generate_opcheck_tests` will print a + repro command on failure. + + In order to print the repro command, we need to save some tensors to disk. + These will be saved under the following directory: + {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/. + + Although this is a temp folder, it will usually not automatically get cleaned + up, so you'll need to manually delete it. + """ + key = "PYTORCH_OPCHECK_PRINT_BETTER_REPRO" + if key not in os.environ: + return False + value = os.environ[key] + return value == "1" or value == 1 + + +def opcheck( + op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS, + raise_exception: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> dict[str, str]: + """See torch.library.opcheck for docstring""" + + if (rtol is None) ^ (atol is None): + raise ValueError( + "opcheck(op, ...): if you specify one of rtol/atol, you must specify both" + ) + + if kwargs is None: + kwargs = {} + if isinstance(op, CustomOpDef): + op = op._opoverload + if isinstance(op, torch._ops.OpOverloadPacket): + op = resolve_unique_overload_or_throw(op) + if not isinstance(op, torch._ops.OpOverload): + raise ValueError( + f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, " + f"e.g. torch.ops.aten.sin.default, got {type(op)}" + ) + if test_utils == "ALL": + test_utils = tuple(ALL_TEST_UTILS.keys()) + if isinstance(test_utils, str): + test_utils = (test_utils,) + if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset( + ALL_TEST_UTILS.keys() + ): + raise ValueError( + f"opcheck(op, ..., test_utils={test_utils}), expected test_utils " + f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not" + ) + + results_dict = {} + for test_util in test_utils: + tester = ALL_TEST_UTILS[test_util] + try: + tester(op, args, kwargs, rtol=rtol, atol=atol) + results_dict[test_util] = "SUCCESS" + except Exception as ex: + if raise_exception: + raise OpCheckError( + f"opcheck(op, ...): {test_util} failed with {ex} " + f"(scroll up for stack trace)" + ) from ex + results_dict[test_util] = ex + return results_dict + + +class OpCheckError(Exception): + pass + + +def generate_repro( + test: str, + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + save_data: bool, + dry_run: bool = False, +) -> str: + if save_data: + now = datetime.datetime.now() + path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete") + unix_timestamp = datetime.datetime.timestamp(now) * 100000 + filepath = os.path.join(path, f"repro_{unix_timestamp}.pt") + if not dry_run: + os.makedirs(path, exist_ok=True) + torch.save((args, kwargs), filepath) + args_kwargs = f'args, kwargs = torch.load("{filepath}")' + else: + args_kwargs = ( + "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n" + "# we will fill them in same (args, kwargs) as in your test\n" + "args = () # args to the operator\n" + "kwargs = {} # kwargs to the operator" + ) + + ns, name = op._schema.name.split("::") + overload = op._overloadname + + repro_command = ( + f"# =========================================================\n" + f"# BEGIN REPRO SCRIPT\n" + f"# =========================================================\n" + f"import torch\n" + f"from torch.testing._internal.optests import opcheck\n" + f"\n" + f"# Make sure you have loaded the library that contains the op\n" + f"# via an import or torch.ops.load_library(...)\n" + f"op = torch.ops.{ns}.{name}.{overload}\n" + f"\n" + f"{args_kwargs}\n" + f'opcheck(op, args, kwargs, test_utils="{test}")\n' + f"# =========================================================\n" + f"# END REPRO SCRIPT\n" + f"# =========================================================\n" + ) + return repro_command + + +def resolve_unique_overload_or_throw( + op: torch._ops.OpOverloadPacket, +) -> torch._ops.OpOverload: + all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name) + if len(all_schemas) != 1: + raise RuntimeError( + f"opcheck can only test operators without overloads. " + f"Got the following overloads for {op._qualified_op_name}: " + f"{[schema.overload_name for schema in all_schemas]}" + ) + + overload_name = all_schemas[0].overload_name + if overload_name == "": + return op.default + return getattr(op, overload_name) + + +DUMP_OPTIONS = {"indent": 2, "sort_keys": True} + + +FailuresDictData = dict[str, dict[str, dict[str, str]]] + + +VERSION = 1 +DESCRIPTION = ( + f"This is a dict containing failures for tests autogenerated by " + f"generate_opcheck_tests. " + f"For more details, please see {GDOC}" +) + + +class FailuresDict: + def __init__(self, path: str, data: FailuresDictData): + self.path = path + self.data = data + + @staticmethod + def load(path, *, create_file=False) -> "FailuresDict": + if create_file and not os.path.exists(path): + result = FailuresDict(path, {}) + FailuresDict.save() + return result + with open(path) as fp: + contents = fp.read() + if contents.strip() == "": + dct = { + "_description": DESCRIPTION, + "data": {}, + "_version": VERSION, + } + else: + dct = json.loads(contents) + assert "data" in dct + assert "_version" in dct and dct["_version"] == VERSION + return FailuresDict(path, dct["data"]) + + def _save(self, to_str=False) -> Optional[str]: + to_dump = { + "_description": DESCRIPTION, + "data": self.data, + "_version": VERSION, + } + # json.dumps doesn't end with a newline. Let's add one because files + # should end in newlines. + serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n" + if to_str: + return serialized + with open(self.path, "w") as fp: + fp.write(serialized) + return None + + def save(self) -> None: + return self._save() + + def get_status(self, qualname: str, test_name: str) -> str: + if qualname not in self.data: + return "xsuccess" + dct = self.data[qualname] + if test_name not in dct: + return "xsuccess" + return dct[test_name]["status"] + + def set_status( + self, + qualname: str, + test_name: str, + status: str, + *, + comment: Optional[str] = None, + ): + if qualname not in self.data: + self.data[qualname] = {} + dct = self.data[qualname] + if test_name not in dct: + dct[test_name] = {"status": None, "comment": ""} + + if status == "xsuccess": + # The default status is "xsuccess". + del dct[test_name] + else: + dct[test_name]["status"] = status + if comment is not None: + dct[test_name]["comment"] = comment diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..970a0be1b36956d3693a5a93d07dbf32027c9773 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py @@ -0,0 +1,89 @@ +# mypy: ignore-errors + +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._utils import wrapper_set_seed +import torch.utils._pytree as pytree + + +def make_fx_check( + func, + args, + kwargs, + tracing_mode, + assert_close=torch.testing.assert_close, + randomize_data=False, +): + f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs) + + def run(f, *args, **kwargs): + return wrapper_set_seed(f, *args, **kwargs) + + traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args) + + msg = ( + "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different " + "values. This could mean that your abstract impls (meta/FakeTensor impls) " + "are incorrect, that your operator is not completely traceable (e.g., " + "it relies on some global state), or that there is a bug in make_fx. " + "Note that if you passed a python function (and not an operator) to " + "make_fx_check, it is still possible that the python function will still " + "work with torch.compile because it handles capturing pieces of " + "your python code to compile." + ) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if we used + # random data and it fails. + if randomize_data: + new_args = randomize(new_args) + try: + expected = run(f, *new_args) + except Exception: + if randomize_data: + return + raise + result = run(traced_f, *new_args) + assert_close(result, expected, msg=msg) + + +# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes. +# Absent that, here is our strategy: +# +# If any argument is a torch.Size(), maybe get dynamic shapes for it by: +# - Create a temporary Tensor whose size is the torch.Size() we want. Note that +# we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx. +# - Pass it to make_fx such that it is converted to a proxy Tensor +# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in +# symbolic mode, a no-op otherwise) +def handle_sizes_for_dynamic_shapes(func, args, kwargs): + def f(args, kwargs, extra_args, extra_kwargs): + if extra_args: + for i, t in extra_args: + args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() + + return func(*args, **kwargs) + + extra_args = [] + extra_kwargs = {} + for i, arg in enumerate(args): + if isinstance(arg, torch.Size): + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") + + return f, args, kwargs, extra_args, extra_kwargs + + +def randomize(args): + def transform(x): + if not x.dtype.is_floating_point: + return x + return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad) + return pytree.tree_map_only(torch.Tensor, transform, args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py new file mode 100644 index 0000000000000000000000000000000000000000..abc4ab6f7e4734361ec7ecea3d4755910f9cf2ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py @@ -0,0 +1,33 @@ +# mypy: ignore-errors + +import math + +import torch +import torch.nn as nn + + +class LinearReluFunctionalChild(nn.Module): + def __init__(self, N): + super().__init__() + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x + +class LinearReluFunctional(nn.Module): + def __init__(self, N): + super().__init__() + self.child = LinearReluFunctionalChild(N) + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = self.child(x) + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3494f945fad36d84cb8056dcf722d6911f0af2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + + + +def div_int_future(): + return 1 / 2 + + +def div_float_future(): + return 3.14 / 0.125 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..164e6d168414a11039f3b63885760ad08b81ae99 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch # noqa: F401 + + +def div_int_nofuture(): + return 1 / 2 + + +def div_float_nofuture(): + return 3.14 / 0.125 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..e5162ba0d6cb6729534ab28f8a84a906f8c99f87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py @@ -0,0 +1,194 @@ +# mypy: allow-untyped-defs +import contextlib +from pathlib import Path +from typing import Optional + +import torch + + +_TORCHBIND_IMPLS_INITIALIZED = False + +_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None + + +def init_torchbind_implementations(): + global _TORCHBIND_IMPLS_INITIALIZED + global _TENSOR_QUEUE_GLOBAL_TEST + if _TORCHBIND_IMPLS_INITIALIZED: + return + + load_torchbind_test_lib() + register_fake_operators() + register_fake_classes() + _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() + _TORCHBIND_IMPLS_INITIALIZED = True + + +def _empty_tensor_queue() -> torch.ScriptObject: + return torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + + +# put these under a function because the corresponding library might not be loaded yet. +def register_fake_operators(): + @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta") + def fake_takes_foo(foo, z): + return foo.add_tensor(z) + + @torch.library.register_fake("_TorchScriptTesting::queue_pop") + def fake_queue_pop(tq): + return tq.pop() + + @torch.library.register_fake("_TorchScriptTesting::queue_push") + def fake_queue_push(tq, x): + return tq.push(x) + + torch.library.register_autocast( + "_TorchScriptTesting::queue_push", "cpu", torch.float32 + ) + torch.library.register_autocast( + "_TorchScriptTesting::queue_push", "cuda", torch.float32 + ) + + torch.library.register_autocast( + "_TorchScriptTesting::queue_pop", "cpu", torch.float32 + ) + torch.library.register_autocast( + "_TorchScriptTesting::queue_pop", "cuda", torch.float32 + ) + + @torch.library.register_fake("_TorchScriptTesting::queue_size") + def fake_queue_size(tq): + return tq.size() + + def meta_takes_foo_list_return(foo, x): + a = foo.add_tensor(x) + b = foo.add_tensor(a) + c = foo.add_tensor(b) + return [a, b, c] + + def meta_takes_foo_tuple_return(foo, x): + a = foo.add_tensor(x) + b = foo.add_tensor(a) + return (a, b) + + @torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return") + def meta_takes_foo_tensor_return(foo, x): + # This implementation deliberately creates unbacked symint for testing + ctx = torch.library.get_ctx() + fake_shape = [ctx.new_dynamic_size() for _ in range(2)] + return torch.empty(fake_shape, dtype=torch.int, device="cpu") + + torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl( + torch._C.DispatchKey.Meta + )(meta_takes_foo_list_return) + + torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl( + torch._C.DispatchKey.Meta + )(meta_takes_foo_tuple_return) + + torch.ops._TorchScriptTesting.takes_foo.default.py_impl(torch._C.DispatchKey.Meta)( + # make signature match original cpp implementation to support kwargs + lambda foo, x: foo.add_tensor(x) + ) + + +def register_fake_classes(): + # noqa: F841 + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class FakeFoo: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + @classmethod + def __obj_unflatten__(cls, flattend_foo): + return cls(**dict(flattend_foo)) + + def add_tensor(self, z): + return (self.x + self.y) * z + + @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor") + class FakeContainsTensor: + def __init__(self, t: torch.Tensor): + self.t = t + + @classmethod + def __obj_unflatten__(cls, flattend_foo): + return cls(**dict(flattend_foo)) + + def get(self): + return self.t + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, queue): + self.queue = queue + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + if self.is_empty(): + return torch.empty([]) + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + def is_empty(self): + return len(self.queue) == 0 + + def float_size(self): + return float(len(self.queue)) + + @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp") + class FakeFlatten: + def __init__(self, t): + self.t = t + + def get(self): + return self.t + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + +def load_torchbind_test_lib(): + import unittest + + from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + find_library_location, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, + ) + + if IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + elif IS_SANDCASTLE or IS_FBCODE: + lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations") + elif IS_WINDOWS: + lib_file_path = find_library_location("torchbind_test.dll") + else: + lib_file_path = find_library_location("libtorchbind_test.so") + torch.ops.load_library(str(lib_file_path)) + + +@contextlib.contextmanager +def _register_py_impl_temporarily(op_overload, key, fn): + try: + op_overload.py_impl(key)(fn) + yield + finally: + del op_overload.py_kernels[key] + op_overload._dispatch_cache.clear() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ccbf9ed0bfc9286b083d65faeaf5386e10d2655 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c17938ca82510cc0184ed27eee1e4ff1ae27c01 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d06b68283f6218d95396636a76fb156aaeca3da Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..540cd6a5f5f98c5f8b62ce15d873867a03e87a6b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6543a53d034046f109d7cd3bcfcc70c2b4bebc1b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56a440af81d989505b42866e5c323eb651619283 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2978ce02498f64d851a8e91476f28a491ddeb1e9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..625f1b1f4523dd14f0b4b9f8c499eb4180758faf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_debug_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_debug_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71ae5d1392743040c0238045065305849184c03c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_debug_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b32fa5bf5e2b7a61c353396b79c8e846bb7da6b5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f175945dabb97b386ee763a9d796bb6ab9bfe6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f52d1e69faada0b0c4a51d63ed44757f29e52258 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77d00c9979bc4dc6c5660fb836bc239ecc1a727c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce5a0f1c7bcf7d56ccd08f03382cf898e297d35b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a6d06aabea4f602c048e6ac921bf871a044db0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cff688c1b53a1d9c269c083b76bb8ebf47b349e0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_helion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_helion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b80cf675dcd197a82a831e2ef1095c353843b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_helion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14be8a1b8e39fcc6e7d6bfd292f3026ccdebf70b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e236489429f1cf889b5c6238308b4e763bbcb77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..240b2b90cefb1fdcdf1133465987e46faa00c477 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pallas.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pallas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2423273edd764cc5b7187413187b5b65b1a07b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pallas.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39dfff8856e64598f13c7980f932304ee6839ea8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8f457f6ab46157fd393739b578cbb44db9535bf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da0b7588fddaafe750b318eb3e0ad10b17754bdc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4020cdaf6598dc4ed58c509681bd411d7020cabe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70c2ccee17975b3eafffc55b0ed2aeadde9f1974 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdb958a2640348fccc56f0548eb9261fcb22ca62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff8b25a16fdf0c340b2f9453f2d21dba25d9cbdd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e902e76617eb0e3592281dda8c9c258a6a50f1d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_zip.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_zip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..216c5ca42afd5f6fc8ec3bb1108eb20973374db1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/_zip.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4896b1242c46385ae91c0df0bf7ed56df1448ee3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96ebd1faaaafc9c668ea993ce7512150503ae9e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea12c7002d0b2d5048e8e640b749099dbba44b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..411171c29792d6362a2936a71c47b10084ae000c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43d147ef45de8e6fb1ca6f1cb3698e3b64dab6e0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a721f4427a475a1492c5f090f7fd21965d61c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9715e26cd9d651e677020503f319de12b8f6a04 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e541f71bb4f51bb781c61f56fb969f4f83122546 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/flop_counter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/flop_counter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85819256960db87ac1dfb942d68234db65cfe56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/flop_counter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..718469c7be8ad1cdad70908422e8556226b5f61c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mkldnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mkldnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95f77e5eb19340ed54a0d24b7a968d71ebd8ed02 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mkldnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fce4706907e6c0c9730f51666ec3e7327670cb0b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19348ff9cb03a419c2bea425cb8833570e4e4061 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/module_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/module_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..373e9fc4a879f7f90da091941d635c2c3429a6b1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/module_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75af73084669dbe7b58a19bdb4225c24500788a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6a3245db6979edf07a753ba03b3e6736cbe8032 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7961a80919b7a994383884203ecea8904348f4f7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4f6248d00667054d23fe26978d0105056ced31 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e57a7885ecb2db48ce7326e3a9966a87fce36a28 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e1595bf2a1477b33ed00446d86e6cdea267a8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py @@ -0,0 +1,313 @@ +# mypy: disallow-untyped-defs + +import functools +import logging +import os +import re +import subprocess +import time +from collections.abc import Callable, Sequence +from threading import Lock +from typing import Any, TypeVar +from typing_extensions import ParamSpec + + +logger = logging.getLogger("strobelight_function_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class StrobelightCLIProfilerError(Exception): + """ + Raised when an error happens during strobelight profiling + """ + + +def _pid_namespace_link(pid: int | None = None) -> str: + """Returns the link to the process's namespace, example: pid:[4026531836]""" + PID_NAMESPACE_PATH = "/proc/{}/ns/pid" + pid = pid or os.getpid() + return os.readlink(PID_NAMESPACE_PATH.format(pid)) + + +def _pid_namespace(pid: int | None = None) -> int: + """Returns the process's namespace id""" + pid = pid or os.getpid() + link = _pid_namespace_link(pid) + return int(link[link.find("[") + 1 : -1]) + + +def _command_to_string(command: Sequence[str]) -> str: + return " ".join(command) + + +class StrobelightCLIFunctionProfiler: + """ + Note: this is a meta only tool. + + StrobelightCLIFunctionProfiler can be used to profile a python function and + generate a strobelight link with the results. It works on meta servers but + does not requires an fbcode target. + When stop_at_error is false(default), error during profiling does not prevent + the work function from running. + + Check function_profiler_example.py for an example. + """ + + # This lock is used to make sure only one thread is running the profiler at any point. + _lock = Lock() + + def __init__( + self, + *, + stop_at_error: bool = False, + max_profile_duration_sec: int = 60 * 10, + sample_each: float = 1e7, # sample each sample_each cycles. + run_user_name: str = "pytorch-strobelight-ondemand", + timeout_wait_for_running_sec: int = 60, + timeout_wait_for_finished_sec: int = 60, + recorded_env_variables: list[str] | None = None, + sample_tags: list[str] | None = None, + stack_max_len: int = 127, + async_stack_max_len: int = 127, + ) -> None: + self.stop_at_error = stop_at_error + self.max_profile_duration_sec = max_profile_duration_sec + self.sample_each = sample_each + self.run_user_name = run_user_name + self.timeout_wait_for_running_sec = timeout_wait_for_running_sec + self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec + # Results of the most recent run. + # Tracks the strobelight run id of the most recent run + self.current_run_id: int | None = None + self.sample_tags = sample_tags + + def _run_async(self) -> None: + processId = os.getpid() + namespace = _pid_namespace(processId) + command = [ + "strobeclient", + "run", + "--profiler", + "pyperf", + "--event", + "cycles", + "--async", + "--sample-interval", + f"{int(self.sample_each)}", + "--duration-ms", + f"{int(self.max_profile_duration_sec * 1000)}", + "--pid", + f"{namespace}:{processId}", + ] + + if self.sample_tags: + command.append("--sample-tags") + command.append(",".join(self.sample_tags)) + + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in run_async:{output}" + ) + + if match := re.search(r"INFO Run Id: (-?\d+)", output): + self.current_run_id = int(match.group(1)) + return + + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, unexpected result {output}" + ) + + def _wait_for_running(self, counter: int = 0) -> None: + if counter > 20: + raise StrobelightCLIProfilerError( + "wait_for_running called more than 20 times" + ) + + command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in wait_for_running:{output}" + ) + + if match := re.search("Profile run status: (.*)", output): + current_status = match.group(1) + if current_status == "RUNNING": + return + elif current_status == "PREPARING": + time.sleep(10) + self._wait_for_running(counter + 1) + return + else: + raise StrobelightCLIProfilerError(f"unexpected {current_status} phase") + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _stop_run(self) -> None: + command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, return code is not 0 :{output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Success!"): + return + else: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, got {current_status} result" + ) + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _get_results(self) -> None: + command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, return code is not 0 : {output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Profile run status: PROCESSING"): + time.sleep(10) + self._get_results() + return + elif not current_status.__contains__("Profile run finished with SUCCESS"): + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, unexpected response {output}" + ) + + for item in re.findall( + r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))", + output, + ): + logger.info(item[0]) + + def _stop_strobelight_no_throw( + self, + collect_results: bool, + ) -> None: + try: + # call stop run + self._stop_run() + logger.info("strobelight profiling stopped") + + logger.debug("collection stopped") + + if not collect_results: + return + + self._get_results() + except Exception: + logger.warning("error during stop_strobelight", exc_info=True) + + # Return true if strobelight started and is running. Never throw. + def _start_strobelight(self) -> bool: + strobelight_started = False + try: + self._run_async() + strobelight_started = True + logger.info("strobelight run id is: %s", self.current_run_id) + self._wait_for_running() + logger.info("strobelight profiling running") + return True + + except Exception: + logger.warning("error during start_strobelight:", exc_info=True) + if strobelight_started: + self._stop_strobelight_no_throw(collect_results=False) + return False + + def profile( + self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs + ) -> _R | None: + self.current_run_id = None + + if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): + if not locked: + if self.stop_at_error: + raise StrobelightCLIProfilerError("concurrent runs not supported") + + logger.warning("concurrent runs not supported") + return work_function(*args, **kwargs) + + started = self._start_strobelight() + if not started: + if self.stop_at_error: + StrobelightCLIFunctionProfiler._lock.release() + raise StrobelightCLIProfilerError( + "failed to start strobelight profiling" + ) + result = work_function(*args, **kwargs) + StrobelightCLIFunctionProfiler._lock.release() + return result + + try: + logger.debug("collection started") + result = work_function(*args, **kwargs) + self._stop_strobelight_no_throw(collect_results=True) + StrobelightCLIFunctionProfiler._lock.release() + return result + except Exception as error: + logger.warning("work function throw exception", exc_info=True) + self._stop_strobelight_no_throw(collect_results=False) + StrobelightCLIFunctionProfiler._lock.release() + raise error + return None + + +# A function decorator that wraps profile, if no profiler is provided one with +# default args is created. A function can be annotated as: +# @strobelight() +# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) +# @strobelight(stop_at_error=True,...) +def strobelight( + profiler: StrobelightCLIFunctionProfiler | None = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, _R | None]]: + if not profiler: + profiler = StrobelightCLIFunctionProfiler(**kwargs) + + def strobelight_inner( + work_function: Callable[_P, _R], + ) -> Callable[_P, _R | None]: + @functools.wraps(work_function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _R | None: + # pyrefly: ignore [bad-argument-type] + return profiler.profile(work_function, *args, **kwargs) + + return wrapper_function + + return strobelight_inner diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b10f04ed9ff633081daac8e5c051aef36064a49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e8164e40e80caa1f3a6682452f3348c1ea6d124 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c9148e8de08d8576c14838b72fb5b3e7c72ea5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f24d0c4e0af18e0377327808cc10475d5e85de20 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48470c5e994a075a4214045a1efcc77423bf2db0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed9f4b499c89498af5c1a3efb11bcdbb388149af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..555619b0aeddf3f4e2e70a3ab6894207b86ba3a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2455e2c0bc7910280bdba009b0b50ad7bf95630e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e39ac32fd4972c27cb23a0c82f7510ece894d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86aa9d4659b8e24027cb64363598db41743469ca Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..0816a2c23d6484b8b4e7bca0a9225554ae7770b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/functions.py @@ -0,0 +1,1463 @@ +# mypy: allow-untyped-defs +import functools +import math +import operator +import sys +from collections.abc import Callable +from typing import SupportsFloat, TYPE_CHECKING, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +import sympy +from sympy import S +from sympy.core import sympify +from sympy.core.expr import Expr +from sympy.core.function import Application +from sympy.core.logic import _torf, fuzzy_and, fuzzy_or +from sympy.core.numbers import equal_valued +from sympy.core.operations import LatticeOp, ShortCircuit +from sympy.core.sorting import ordered +from sympy.core.traversal import walk +from sympy.printing.precedence import PRECEDENCE +from sympy.utilities.iterables import sift + +from torch.torch_version import TorchVersion + +from .numbers import int_oo + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +_T = TypeVar("_T", bound=SupportsFloat) +_Ts = TypeVarTuple("_Ts") + +# Portions of this file are adapted from the Sympy codebase, which was +# licensed as follows: +# +# Copyright (c) 2006-2023 SymPy Development Team +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of SymPy nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + +__all__ = [ + "FloorDiv", + "ModularIndexing", + "Where", + "PythonMod", + "Mod", + "CleanDiv", + "CeilToInt", + "FloorToInt", + "CeilDiv", + "IntTrueDiv", + "FloatTrueDiv", + "LShift", + "RShift", + "IsNonOverlappingAndDenseIndicator", + "TruncToFloat", + "TruncToInt", + "RoundToInt", + "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", + "Identity", +] + + +def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: + # No need to check that two args are not the same, since expr is pr-optimized but we do it anyway. + return ( + expr.is_Add + and len(expr._args) == 2 + and expr._args[0].is_symbol + and expr._args[1].is_symbol + and expr._args[0] is not expr._args[1] + ) + + +def _keep_float( + f: Callable[[Unpack[_Ts]], _T], +) -> Callable[[Unpack[_Ts]], _T | sympy.Float]: + @functools.wraps(f) + def inner(*args: Unpack[_Ts]) -> _T | sympy.Float: + # pyrefly: ignore [bad-argument-type] + r: _T | sympy.Float = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + # pyrefly: ignore [bad-return] + return inner + + +def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None: + if None in (x, y): + return None + return x == y + + +def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic: + """ + Fast path for sympy.gcd, using a simple factoring strategy. + + We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, + where n is the greatest common integer factor and e is the largest + syntactic common factor (i.e., common sub-expression) in p and q. + Then the gcd returned is n*e, cancelling which we would be left with + p1 + p2 and q0. + + Note that further factoring of p1 + p2 and q0 might be possible with + sympy.factor (which uses domain-specific theories). E.g., we are unable + to find that x*y + x + y + 1 is divisible by x + 1. More generally, + when q is of the form q1 + q2 (instead of being already factored) it + might be necessary to fall back on sympy.gcd. + """ + + def integer_coefficient(x: sympy.Basic) -> int: + integer_coefficients: list[int] = [ + abs(int(arg)) + for arg in sympy.Mul.make_args(x) + if isinstance(arg, (int, sympy.Integer)) + ] + return math.prod(integer_coefficients) + + def integer_factor(expr: sympy.Basic) -> int: + integer_factors: Iterable[int] = map( + integer_coefficient, sympy.Add.make_args(expr) + ) + return functools.reduce(math.gcd, integer_factors) + + gcd: int = math.gcd(integer_factor(p), integer_factor(q)) + p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12 + + base_splits: list[tuple[sympy.Basic, ...]] = list( + map(sympy.Mul.make_args, sympy.Add.make_args(p)) + ) + divisor_split: tuple[sympy.Basic, ...] = sympy.Mul.make_args(q) + for x in divisor_split: + if all(x in base_split for base_split in base_splits): + gcd = gcd * x # type: ignore[operator] # remove in py3.12 + return gcd # type: ignore[return-value] # remove in py3.12 + + +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. +class FloorDiv(sympy.Function): + """ + We maintain this so that: + 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. + 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf + """ + + nargs: tuple[int, ...] = (2,) + precedence: int = 35 # lower precedence than add + is_integer: bool = True + + @property + def base(self) -> sympy.Basic: + # pyrefly: ignore [missing-attribute] + return self.args[0] + + @property + def divisor(self) -> sympy.Basic: + # pyrefly: ignore [missing-attribute] + return self.args[1] + + def _sympystr(self, printer: sympy.printing.StrPrinter) -> str: + base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5) + divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5) + return f"({base}//{divisor})" + + # Automatic evaluation. + # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval + @classmethod + def eval(cls, base: sympy.Integer, divisor: sympy.Integer) -> sympy.Basic | None: + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor + + # We don't provide the same error message as in Python because SymPy + # makes it difficult to check the types. + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in ( + int_oo, + -int_oo, + sympy.oo, + -sympy.oo, + ): + return sympy.nan + if base is sympy.nan or divisor is sympy.nan: + return sympy.nan + + if base.is_zero: + return sympy.S.Zero + if base.is_integer and equal_valued(divisor, 1): + return base + if base.is_integer and equal_valued(divisor, -1): + return sympy.Mul(base, -1) + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + r = float(base) / float(divisor) + if r == math.inf: + return int_oo + elif r == -math.inf: + return -int_oo + elif math.isnan(r): + return sympy.nan + else: + return sympy.Integer(math.floor(r)) + if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): + return sympy.Integer(int(base) // int(divisor)) + if isinstance(base, FloorDiv): + return FloorDiv(base.args[0], base.args[1] * divisor) + + # Expands (x + y) // b into x // b + y // b. + # This only works if floor is an identity, i.e. x / b is an integer. + if isinstance(divisor, sympy.Integer): + quotients = 0 + terms = [] + for term in sympy.Add.make_args(base): + quotient = term / divisor + + # This is a sympy bug fixed in https://github.com/sympy/sympy/pull/28442 + # sympy can generate a quotient with (1/22)*.... such that quotient.is_integer is True + # FloorDiv should not allow that as output. see + quotient_is_integer = None + if isinstance(quotient, sympy.Mul) and TorchVersion( + sympy.__version__ + ) < TorchVersion("1.15.0"): + rationals = quotient.atoms(sympy.Rational) + all_rationals_ints = all(r.q == 1 for r in rationals) + quotient_is_integer = quotient.is_integer and all_rationals_ints + else: + quotient_is_integer = quotient.is_integer + + if quotient_is_integer: + terms.append(term) + quotients += quotient + + if len(terms) != 0: + # Passing evaluate = False since expression will be optimized during the subtraction post its construction. + return ( + FloorDiv(base - sympy.Add(*terms, evaluate=False), divisor) + + quotients + ) + + try: + gcd = simple_floordiv_gcd(base, divisor) + if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add): + gcd = sympy.gcd(base, divisor) + if not equal_valued(gcd, 1): + return FloorDiv( + sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) + ) + except sympy.PolynomialError: + pass # https://github.com/pytorch/pytorch/issues/108276 + + return None + + +class ModularIndexing(sympy.Function): + """ + ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus + """ + + nargs: tuple[int, ...] = (3,) + is_integer: bool = True + precedence: int = 35 # lower precedence than add + + @classmethod + def eval( + cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer + ) -> sympy.Basic | None: + if base == 0 or modulus == 1: + return sympy.S.Zero + if ( + isinstance(base, sympy.Integer) + and isinstance(divisor, sympy.Integer) + and isinstance(modulus, sympy.Integer) + ): + return (base // divisor) % modulus + + try: + if divisor != 1: + gcd = sympy.gcd(base, divisor) + if gcd != 1: + return ModularIndexing( + sympy.simplify(base / gcd), + sympy.simplify(divisor / gcd), + modulus, + ) + except sympy.PolynomialError: + pass # https://github.com/pytorch/pytorch/issues/108276 + + if isinstance(base, sympy.Add): + new_terms: list[sympy.Integer] = [] + all_positive: bool = True + for term in base.args: + if sympy.gcd(term, modulus * divisor) != modulus * divisor: + if (isinstance(term, sympy.Integer) and term < 0) or ( + isinstance(term, sympy.Mul) + and isinstance(term.args[0], sympy.Integer) + and term.args[0] < 0 + ): + # workaround for https://github.com/triton-lang/triton/issues/619, + # if there are negative terms, // produces wrong result + # TODO if https://github.com/triton-lang/triton/issues/619 is fixed + # this optimization would become valid + all_positive = False + break + else: + new_terms.append(term) + + if len(new_terms) != len(base.args) and all_positive: + return ModularIndexing(sum(new_terms), divisor, modulus) + + if isinstance(base, FloorDiv): + return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) + + return None + + def _eval_is_nonnegative(self) -> bool | None: + # pyrefly: ignore [missing-attribute] + p, q = self.args[:2] + return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] + + +class Where(sympy.Function): + """ + Good ol' ternary operator + """ + + nargs: tuple[int, ...] = (3,) + precedence: int = 35 # lower precedence than add + + def _eval_is_integer(self) -> bool | None: + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self) -> bool | None: + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self) -> bool | None: + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + + @classmethod + def eval(cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic) -> sympy.Basic | None: + if c == sympy.true: + return p + elif c == sympy.false: + return q + return None + + +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): + nargs: tuple[int, ...] = (2,) + + precedence: int = 35 # lower precedence than add + is_integer: bool = True + + @classmethod + def eval(cls, p: sympy.Expr, q: sympy.Expr) -> sympy.Expr | None: + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + # pyrefly: ignore [missing-attribute] + if less.is_Boolean and bool(less) and r.is_positive: + return p + + if sympy.Mod(p, q) == 0: + return S.Zero + + return None + + # NB: args[1] for PythonMod + def _eval_is_nonnegative(self) -> bool | None: + return True if self.args[1].is_positive else None # type: ignore[attr-defined] + + def _eval_is_nonpositive(self) -> bool | None: + return True if self.args[1].is_negative else None # type: ignore[attr-defined] + + def _ccode(self, printer) -> str: + # pyrefly: ignore [missing-attribute] + p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore [missing-attribute] + q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore [missing-attribute] + abs_q = str(q) if self.args[1].is_positive else f"abs({q})" + return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}" + + +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + precedence: int = 35 # lower precedence than add + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + if p < 0: + raise AssertionError(p) + if q < 1: + raise AssertionError(q) + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + +class CleanDiv(FloorDiv): + """ + Div where we can assume no rounding. + This is to enable future optimizations. + """ + + +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + def _ccode(self, printer) -> str: + # pyrefly: ignore [missing-attribute] + number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5) + return f"ceil({number})" + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, int_oo): + return -int_oo + if isinstance(number, sympy.Integer): + return number + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + +class CeilDiv(sympy.Function): + """ + Div used in indexing that rounds up. + """ + + is_integer = True + + def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) + if sympy.gcd(base, divisor) == divisor: + return CleanDiv(base, divisor) + else: + return FloorDiv(base + (divisor - 1), divisor) + + +class LShift(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, base, shift): + if shift < 0: + raise ValueError("negative shift count") + return base * 2**shift + + +class RShift(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, base, shift): + if shift < 0: + raise ValueError("negative shift count") + return FloorDiv(base, 2**shift) + + +class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] + def __new__(cls, *original_args, **assumptions): + from sympy.core.parameters import global_parameters + + evaluate = assumptions.pop("evaluate", global_parameters.evaluate) + args = (sympify(arg) for arg in original_args) + + # See the comment in _satisfy_unique_summations_symbols. + unique_summations_symbols = ( + None + if not evaluate + else cls._satisfy_unique_summations_symbols(original_args) + ) + + if evaluate: + try: + # first standard filter, for cls.zero and cls.identity + # also reshape Max(a, Max(b, c)) to Max(a, b, c) + args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment] + except ShortCircuit: + return cls.zero # type: ignore[attr-defined] + + # No need to run _collapse_arguments and _find_localzeros, see the comment + # in _satisfy_unique_summations_symbols. + if unique_summations_symbols is None: + # remove redundant args that are easily identified + args = cls._collapse_arguments(args, **assumptions) + + # find local zeros + args = cls._find_localzeros(args, **assumptions) + + args = frozenset(args) + + if not args: + return cls.identity # type: ignore[attr-defined] + + if len(args) == 1: + return list(args).pop() + + # base creation + obj = Expr.__new__(cls, *ordered(args), **assumptions) + obj._argset = args + + obj.unique_summations_symbols = unique_summations_symbols + return obj + + @classmethod + def _satisfy_unique_summations_symbols( + cls, args + ) -> set[sympy.core.symbol.Symbol] | None: + """ + One common case in some models is building expressions of the form + max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...). + For such expressions, we call the Max constructor X times (once for each nested + max) and the expression gets flattened. + + An expensive cost in constructing those expressions is running _collapse_arguments + and _find_localzeros. However, those two optimizations are unnecessary when the args + to max are all of the form a+b, c+d, ..etc where each term uses a unique set of symbols. + + This function is used to detect such properties of the expressions we are building + and if so inform that we do not need to run those optimizations. To detect those, + we store a property in the expression that tells that this expression is a min/max + operation over terms that use unique symbols "unique_summations_symbols". This property + also memoize the set of symbols used in all the terms to make it faster to detect this + property inductively. + + When we apply max to add a new term, all we need to do is check if the new term uses + unique symbols (with respect to existing terms and itself). + Example: + t = Max(a+b, c+d) ==> satisfies the property + Max(t, h+j) ==> h,j not in [a,b,c,d] => satisfy the property. + + The function returns None if the new expression does not satisfy the unique_summations_symbols + property. Otherwise, it returns a new set of unique symbols. + """ + if len(args) != 2: + return None + + (lhs, rhs) = ( + (args[1], args[0]) + if isinstance(args[1], MinMaxBase) + else (args[0], args[1]) + ) + + if not _is_symbols_binary_summation(rhs): + return None + + # base case max(a+b, c+d) ==> satisfies the property if a+b and c+d use unique symbols. + if _is_symbols_binary_summation(lhs): + return cls._unique_symbols(args) + + # inductive case max(t, h+j) ==> satisfies the property if h, j not in t.unique_summations_symbols + if isinstance(lhs, MinMaxBase): + lhs_unique_summations_symbols = getattr( + lhs, "unique_summations_symbols", None + ) + if lhs_unique_summations_symbols is not None: + return cls._unique_symbols([rhs], lhs_unique_summations_symbols) + + return None + + @classmethod + def _unique_symbols( + cls, args, initial_set: set[sympy.core.symbol.Symbol] | None = None + ) -> set[sympy.core.symbol.Symbol] | None: + """ + Return seen_symbols if all atoms in all args are all unique symbols, + else returns None. initial_set can be used to represent initial value for seen_symbols + """ + seen_symbols = set() if initial_set is None else initial_set + for arg in args: + for element in arg.atoms(): + if not isinstance(element, sympy.core.symbol.Symbol): + return None + elif element in seen_symbols: + return None + else: + seen_symbols.add(element) + return seen_symbols + + @classmethod + def _collapse_arguments(cls, args, **assumptions): + """Remove redundant args. + + Examples + ======== + + >>> from sympy import Min, Max + >>> from sympy.abc import a, b, c, d, e + + Any arg in parent that appears in any + parent-like function in any of the flat args + of parent can be removed from that sub-arg: + + >>> Min(a, Max(b, Min(a, c, d))) + Min(a, Max(b, Min(c, d))) + + If the arg of parent appears in an opposite-than parent + function in any of the flat args of parent that function + can be replaced with the arg: + + >>> Min(a, Max(b, Min(c, d, Max(a, e)))) + Min(a, Max(b, Min(a, c, d))) + """ + if not args: + return args + args = list(ordered(args)) + if cls is Min: + other = Max + else: + other = Min # type: ignore[assignment] + + # find global comparable max of Max and min of Min if a new + # value is being introduced in these args at position 0 of + # the ordered args + if args[0].is_number: + sifted = mins, maxs = [], [] # type: ignore[var-annotated] + for i in args: + for v in walk(i, Min, Max): + if v.args[0].is_comparable: + sifted[isinstance(v, Max)].append(v) + small = Min.identity + for i in mins: + v = i.args[0] + if v.is_number and (v < small) == True: # noqa: E712 + small = v + big = Max.identity + for i in maxs: + v = i.args[0] + if v.is_number and (v > big) == True: # noqa: E712 + big = v + # at the point when this function is called from __new__, + # there may be more than one numeric arg present since + # local zeros have not been handled yet, so look through + # more than the first arg + if cls is Min: + for arg in args: + if not arg.is_number: + break + if (arg < small) == True: # noqa: E712 + small = arg + elif cls == Max: + for arg in args: + if not arg.is_number: + break + if (arg > big) == True: # noqa: E712 + big = arg + T = None + if cls is Min: + if small != Min.identity: + other = Max + T = small + elif big != Max.identity: + other = Min # type: ignore[assignment] + T = big + if T is not None: + # remove numerical redundancy + for i in range(len(args)): + a = args[i] + if isinstance(a, other): + a0 = a.args[0] + if ( # noqa: E712 + (a0 > T) if other == Max else (a0 < T) # noqa: E712 + ) == True: # noqa: E712 + args[i] = cls.identity # type: ignore[attr-defined] + + # remove redundant symbolic args + def do(ai, a): + if not isinstance(ai, (Min, Max)): + return ai + cond = a in ai.args + if not cond: + return ai.func(*[do(i, a) for i in ai.args], evaluate=False) + if isinstance(ai, cls): + # pyrefly: ignore [missing-attribute] + return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False) + return a + + for i, a in enumerate(args): + args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]] + + # factor out common elements as for + # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z)) + # and vice versa when swapping Min/Max -- do this only for the + # easy case where all functions contain something in common; + # trying to find some optimal subset of args to modify takes + # too long + + def factor_minmax(args): + is_other = lambda arg: isinstance(arg, other) # noqa: E731 + other_args, remaining_args = sift(args, is_other, binary=True) + if not other_args: + return args + + # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v}) + arg_sets = [set(arg.args) for arg in other_args] + common = set.intersection(*arg_sets) + if not common: + return args + + new_other_args = list(common) + arg_sets_diff = [arg_set - common for arg_set in arg_sets] + + # If any set is empty after removing common then all can be + # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b) + if all(arg_sets_diff): + other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff] + new_other_args.append(cls(*other_args_diff, evaluate=False)) + + other_args_factored = other(*new_other_args, evaluate=False) + return remaining_args + [other_args_factored] + + if len(args) > 1: + args = factor_minmax(args) + + return args + + @classmethod + def _new_args_filter(cls, arg_sequence): + """ + Generator filtering args. + + first standard filter, for cls.zero and cls.identity. + Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``, + and check arguments for comparability + """ + for arg in arg_sequence: + # pre-filter, checking comparability of arguments + if ( + not isinstance(arg, Expr) + or arg.is_extended_real is False + or (arg.is_number and not arg.is_comparable) + ): + raise ValueError(f"The argument '{arg}' is not comparable.") + + if arg == cls.zero: # type: ignore[attr-defined] + raise ShortCircuit(arg) + elif arg == cls.identity: # type: ignore[attr-defined] + continue + elif arg.func == cls: + yield from arg.args + else: + yield arg + + @classmethod + def _find_localzeros(cls, values, **options): + """ + Sequentially allocate values to localzeros. + + When a value is identified as being more extreme than another member it + replaces that member; if this is never true, then the value is simply + appended to the localzeros. + + Unlike the sympy implementation, we only look for zero and one, we don't + do generic is connected test pairwise which is slow + """ + + # First, collapse all numeric arguments + other_values = set() + num_value = None + for arg in values: + if arg.is_Number: + if num_value is None: + num_value = arg + else: + if cls is Max: + num_value = max(num_value, arg) + elif cls is Min: + num_value = min(num_value, arg) + else: + raise AssertionError(f"impossible {cls}") + else: + other_values.add(arg) + + # Special cases when there is only one symbolic value + if num_value is None: + return other_values + + if len(other_values) == 0: + return {num_value} + + if len(other_values) == 1: + other_value = next(iter(other_values)) + if num_value in (0.0, 0) and other_value.is_nonnegative: + return other_values if cls is Max else {num_value} + if num_value == 1 and other_value.is_positive: + return other_values if cls is Max else {num_value} + + other_values.add(num_value) + return other_values + + _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 + _eval_is_antihermitian = lambda s: _torf( # noqa: E731 + i.is_antihermitian + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_commutative = lambda s: _torf( # noqa: E731 + i.is_commutative + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 + _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 + _eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731 + _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) # noqa: E731 + _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) # noqa: E731 + _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) # noqa: E731 + _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) # noqa: E731 + _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) # noqa: E731 + _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) # noqa: E731 + _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 + _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 + _eval_is_nonnegative = lambda s: _torf( # noqa: E731 + i.is_nonnegative + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_nonpositive = lambda s: _torf( # noqa: E731 + i.is_nonpositive + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 + _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 + _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731 + _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) # noqa: E731 + _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) # noqa: E731 + _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 + _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 + _eval_is_extended_real = lambda s: _torf( # noqa: E731 + i.is_extended_real + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_transcendental = lambda s: _torf( # noqa: E731 + i.is_transcendental + for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 + + +class Max(MinMaxBase, Application): # type: ignore[misc] + r""" + Return, if possible, the maximum value of the list. + """ + + zero = S.Infinity + identity = S.NegativeInfinity + + def _eval_is_positive(self): # type:ignore[override] + return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): # type:ignore[override] + return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] + + def _eval_is_negative(self): # type:ignore[override] + # pyrefly: ignore [missing-attribute] + return fuzzy_and(a.is_negative for a in self.args) + + +class Min(MinMaxBase, Application): # type: ignore[misc] + """ + Return, if possible, the minimum value of the list. + """ + + zero = S.NegativeInfinity + identity = S.Infinity + + def _eval_is_positive(self): # type:ignore[override] + return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): # type:ignore[override] + return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] + + def _eval_is_negative(self): # type:ignore[override] + # pyrefly: ignore [missing-attribute] + return fuzzy_or(a.is_negative for a in self.args) + + +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +# Prevent people from overflowing pow +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp is int_oo: + return int_oo + + # TODO: microoptimization is to avoid overflowing into arbitrary precision + # and detect overflow prior to doing operations + + result = half_exp * half_exp + if result > sys.maxsize: + return int_oo + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize: + return int_oo + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + + precedence: int = 50 # precedence of mul + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): + r = safe_pow(base, exp) + if r in (-int_oo, int_oo): + return r + return sympy.Integer(r) + if isinstance(exp, sympy.Integer): + # Rely on regular sympy Pow for this (note that iterated + # multiplication turns into a Pow anyway, you can't escape!!) + return sympy.Pow(base, exp) + if exp in (int_oo, sympy.oo): + if base.is_nonnegative: + return int_oo + elif base.is_negative: + return sympy.zoo # this is apparently what (-2)**sympy.oo does + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occurring +class FloatPow(sympy.Function): + is_real = True + + precedence: int = 60 # precedence of pow + + @classmethod + def eval(cls, base, exp): + # NB: These test sympy.Number, not sympy.Float, because: + # - Sometimes we may have sympy.oo or int_oo, and that's not a Float + # (but coerces to math.Inf) + # - Sometimes Float(0.0) will unpredictably decay to Integer(0), + # but we should still accept it in floatey contexts + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning + + +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_real = True + + precedence: int = 35 # lower precedence than add + + @classmethod + def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_real = True + + precedence: int = 35 # lower precedence than add + + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + # Don't have to worry about precision here, you're getting zero or + # inf from the division + return sympy.Float(float(base) / float(divisor)) + if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): + return sympy.Float(int(base) / int(divisor)) + + def _ccode(self, printer) -> str: + # pyrefly: ignore [missing-attribute] + base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore [missing-attribute] + divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) + return f"((int){base}/(int){divisor})" + + +# TODO: As an indicator, this != 0 implies == 1 (and vice versa). +# Because we do not have the ability to guard on the stride permutation +# at the moment, it is hard to make further inferences when this is true, +# as although we know the tensor is contiguous in *some* layout, we don't +# know which one (however, you could, for example, make the inference that +# reshaping this to a 1D tensor can be guard-free.) +class IsNonOverlappingAndDenseIndicator(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, *args): + if len(args) % 2 != 0: + raise AssertionError( + f"expected an even number of arguments, got {len(args)}" + ) + dim = len(args) // 2 + sizes = args[0:dim] + strides = args[dim:] + + # sym_node imported in torch.__init__. Local import to avoid an import cycle + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + if all(isinstance(a, sympy.Integer) for a in args): + return eval_is_non_overlapping_and_dense( + [int(a) for a in sizes], [int(a) for a in strides] + ) + + if dim == 1: + # Manually implement the rank one short circuit + if strides[0].is_Number and strides[0] == 1: + return 1 + + if sizes[0].is_Number and sizes[0] < 2: + return 1 + + # return 0 case covered by case above + + # TODO: Inability to access size-obliviousness sucks: if we have a + # size oblivious test on a size-like unbacked SymInt, we could + # confidently return zero when we have a size-like u0 stride + # and a size-like u1 size. Maybe a fancy ValueRanges analysis for + # this function could help figure this out. + + if all(isinstance(a, sympy.Integer) for a in strides): + if dim == 0: + raise AssertionError("dim must not be zero") + # When all strides are integral, we can sort, and the size for the + # largest stride doesn't matter and can be arbitrarily symbolic + s_sizes, s_strides = zip( + *sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1)), + strict=True, + ) + # Put something arbitrary in the max size spot, it'll be ignored + if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): + s_sizes = s_sizes[:-1] + (42,) + # We can reuse the regular eval, because it is invariant to + # permutation of dimensions + return eval_is_non_overlapping_and_dense( + [int(a) for a in s_sizes], [int(a) for a in s_strides] + ) + + return None + + +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(math.trunc(float(number))) + + +# This is float -> int +class RoundToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + + if number is sympy.oo: + return int_oo + if number is -sympy.oo: + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(round(float(number), 0)) + + +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } + + +# NB: Like Round, this only ever returns floats. ndigits cannot be None +class RoundDecimal(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number, ndigits): + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: + return number + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) + if number is int_oo: + return sympy.oo + if number is -int_oo: + return -sympy.oo + + +class Identity(sympy.Function): + """ + Prevents expansion and other optimizations + """ + + precedence = 10 + + def __repr__(self) -> str: # type: ignore[override] + # pyrefly: ignore [missing-attribute] + return f"Identity({self.args[0]})" + + def _sympystr(self, printer) -> str: + """Controls how sympy's StrPrinter prints this""" + # pyrefly: ignore [missing-attribute] + return f"({printer.doprint(self.args[0])})" + + def _eval_is_real(self): + # pyrefly: ignore [missing-attribute] + return self.args[0].is_real + + def _eval_is_integer(self): + return self.args[0].is_integer # type: ignore[attr-defined] + + def _eval_expand_identity(self, **hints): + # Removes the identity op. + # pyrefly: ignore [missing-attribute] + return self.args[0] + + def __int__(self) -> int: + # pyrefly: ignore [missing-attribute] + return int(self.args[0]) + + def __float__(self) -> float: + # pyrefly: ignore [missing-attribute] + return float(self.args[0]) + + +def make_opaque_unary_fn(name): + class OpaqueUnaryFn(sympy.Function): + """ + Unlike the builtin sympy functions on real numbers like sympy.sqrt, + these equivalents do not do any nontrivial reasoning besides + constant propagation. This helps avoid performing transformations + that are valid for real numbers but are invalid for floating point; + in particular, while we are willing to make optimizations that change + numerics for Tensor compute, we are NOT willing to make optimizations + that change numerics for size compute. + """ + + _torch_handler_name = name + _torch_unpickler = make_opaque_unary_fn + + @classmethod + def eval(cls, a): + if isinstance(a, (sympy.Integer, sympy.Float)): + # Python converts to float64 before computing, c.f. + # >>> math.sin(2**53+1) + # -0.848925964814655 + # >>> math.sin(float(2**53+1)) + # -0.848925964814655 + try: + return sympy.Float(getattr(math, name)(float(a))) + # Just use sympy semantics for infinity/overflow, you might get some + # weird objects but ask silly questions, get silly answers + except OverflowError: + return getattr(sympy, name)(a) + elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]: + if a is int_oo: + a = sympy.oo + if a is -int_oo: + a = -sympy.oo + if name == "log2": + return sympy.log(a, 2) + return getattr(sympy, name)(a) + return None + + nm = "OpaqueUnaryFn_" + name + OpaqueUnaryFn.__name__ = nm + OpaqueUnaryFn.__qualname__ = nm + + return OpaqueUnaryFn + + +# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py +OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt") +OpaqueUnaryFn_cos = make_opaque_unary_fn("cos") +OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh") +OpaqueUnaryFn_sin = make_opaque_unary_fn("sin") +OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh") +OpaqueUnaryFn_tan = make_opaque_unary_fn("tan") +OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh") +OpaqueUnaryFn_asin = make_opaque_unary_fn("asin") +OpaqueUnaryFn_acos = make_opaque_unary_fn("acos") +OpaqueUnaryFn_atan = make_opaque_unary_fn("atan") +OpaqueUnaryFn_exp = make_opaque_unary_fn("exp") +OpaqueUnaryFn_log = make_opaque_unary_fn("log") +OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh") +OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2") + + +def make_opaque_bitwise_fn(name, real_op_name): + if name == "bitwise_and": + prec = PRECEDENCE["BitwiseAnd"] + elif name == "bitwise_xor": + prec = PRECEDENCE["BitwiseXor"] + elif name == "bitwise_or": + prec = PRECEDENCE["BitwiseOr"] + else: + raise AssertionError(f"unrecognized {name}") + + class BitwiseFn(sympy.Function): + _torch_handler_name = name + precedence: int = prec + _torch_unpickler = functools.partial( + make_opaque_bitwise_fn, real_op_name=real_op_name + ) + + @classmethod + def eval(cls, a, b): + if a.is_Boolean and b.is_Boolean: + return getattr(operator, real_op_name)(a, b) + if a.is_Boolean: + a = sympy.Integer(1 if a else 0) + if b.is_Boolean: + b = sympy.Integer(1 if b else 0) + if isinstance(a, (sympy.Integer, int)) and isinstance( + b, (sympy.Integer, int) + ): + return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b))) + return None + + nm = "BitwiseFn_" + name + BitwiseFn.__name__ = nm + BitwiseFn.__qualname__ = nm + + return BitwiseFn + + +BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_") +BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_") +BitwiseFn_bitwise_xor = make_opaque_bitwise_fn("bitwise_xor", "xor") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/interp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/interp.py new file mode 100644 index 0000000000000000000000000000000000000000..6eca9e389d85ae452cbf357d01ca9278239a617d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/interp.py @@ -0,0 +1,228 @@ +# mypy: allow-untyped-defs +""" +This is a simple interpreter for Sympy expressions that dispatches to +classes following the torch._inductor.virtualized calling convention. +For directness, the interpreter takes the handler directly rather than +consulting the TLS. It does not use most of the methods on the full +handler; only those with corresponding Sympy expressions. To see an example +of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis. +""" + +import functools +import logging +from typing import Any + +import sympy +from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom + +import torch + +from .functions import ( + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, + BitwiseFn_bitwise_xor, + CeilToInt, + CleanDiv, + FloatPow, + FloatTrueDiv, + FloorDiv, + FloorToInt, + Identity, + IntTrueDiv, + IsNonOverlappingAndDenseIndicator, + Max, + Min, + Mod, + ModularIndexing, + OpaqueUnaryFn_log2, + PowByNatural, + PythonMod, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, + Where, +) + + +log = logging.getLogger(__name__) + + +# TODO: Dedupe this with SYMPY_INTERP + + +@functools.cache +def handlers(): + # TODO add CeilDiv (it doesn't appear in the index_expr) + + # TODO default to some decompositions if the interpreter doesn't have them + # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a) + + HANDLERS = { + sympy.Or: "or_", + sympy.And: "and_", + sympy.Eq: "eq", + sympy.Ne: "ne", + sympy.Lt: "lt", + sympy.Gt: "gt", + sympy.Le: "le", + sympy.Ge: "ge", + sympy.Not: "not_", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", + FloorDiv: "floordiv", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", + Where: "where", + sympy.Add: "add", + sympy.Mul: "mul", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", + Mod: "mod", + PythonMod: "python_mod", + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup + sympy.Mod: "mod", + sympy.Abs: "abs", + sympy.log: "log", + sympy.exp: "exp", + sympy.Min: "minimum", + sympy.Max: "maximum", + Min: "minimum", + Max: "maximum", + ModularIndexing: "modular_indexing", + sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", + sympy.Piecewise: "piecewise", + Identity: "identity", + IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", + RoundDecimal: "round_decimal", + # TODO: do the rest of the opaque unary functions... + OpaqueUnaryFn_log2: "log2", + BitwiseFn_bitwise_and: "bitwise_and", + BitwiseFn_bitwise_or: "bitwise_or", + BitwiseFn_bitwise_xor: "bitwise_xor", + } + # TODO: This is kind of pointless, we shouldn't be generating sympy.sin + # for these functions, they should be Opaque instead + for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: + HANDLERS[getattr(sympy, name)] = name + + return HANDLERS + + +ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"} + + +def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64): + # Special cases + if isinstance(expr, sympy.Pow) and isinstance( + expr.args[1], sympy.core.numbers.Half + ): + return analysis.sqrt(args[0]) + if isinstance(expr, ToFloat): + return analysis.to_dtype(args[0], torch.float64) + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + + # Fastpath for n-ary integral addition + if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"): + r = analysis.sym_sum(args) + log.debug("sym_sum(%s) -> %s", args, r) + return r + + if hasattr(expr.func, "_torch_handler_name"): + handler_name = expr.func._torch_handler_name + else: + handler_name = handlers()[expr.func] + handler = getattr(analysis, handler_name) + try: + if handler_name in ASSOCIATIVE_OPS: + if len(args) <= 1: + raise AssertionError("associative op needs >1 args") + acc = handler(args[0], args[1]) + for i in range(2, len(args)): + acc = handler(acc, args[i]) + log.debug("%s(%s) -> %s", handler_name, args, acc) + return acc + else: + r = handler(*args) + log.debug("%s(%s) -> %s", handler_name, args, r) + return r + except NotImplementedError: + raise + except Exception: + log.warning("failed while executing %s(%s)", handler_name, args) + raise + + +_nil = object() + + +def sympy_interp( + analysis, + env: dict[sympy.Symbol, Any], + expr: sympy.Expr | SympyBoolean, + *, + index_dtype=torch.int64, + missing_handler=None, +): + # Handle base cases + dtype = None + if isinstance(expr, BooleanAtom): + dtype = torch.bool + elif isinstance(expr, sympy.Integer): + dtype = torch.int64 + elif isinstance(expr, sympy.Number): + dtype = torch.double + + if dtype is not None: + return analysis.constant(expr, dtype) + elif isinstance(expr, sympy.Symbol): + if (r := env.get(expr, _nil)) is not _nil: + return r + elif missing_handler: + return missing_handler(expr) + else: + raise KeyError(expr) + + # Recursive case + return _run_sympy_handler( + analysis, + [ + sympy_interp( + analysis, + env, + arg, + index_dtype=index_dtype, + missing_handler=missing_handler, + ) + for arg in expr.args + ], + expr, + index_dtype=index_dtype, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..8b08e01d8e52bbed86c4630a88974172fe096ab4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py @@ -0,0 +1,399 @@ +# mypy: allow-untyped-defs +import mpmath.libmp as mlib # type: ignore[import-untyped] +import sympy +from sympy import Expr +from sympy.core.decorators import _sympifyit +from sympy.core.expr import AtomicExpr +from sympy.core.numbers import Number +from sympy.core.parameters import global_parameters +from sympy.core.singleton import S, Singleton + + +# pyrefly: ignore [invalid-inheritance] +class IntInfinity(Number, metaclass=Singleton): + r"""Positive integer infinite quantity. + + Integer infinity is a value in an extended integers which + is greater than all other integers. We distinguish it from + sympy's existing notion of infinity in that it reports that + it is_integer. + + Infinity is a singleton, and can be accessed by ``S.IntInfinity``, + or can be imported as ``int_oo``. + """ + + # NB: We can't actually mark this as infinite, as integer and infinite are + # inconsistent assumptions in sympy. We also report that we are complex, + # different from sympy.oo + + is_integer = True + is_commutative = True + is_number = True + is_extended_real = True + is_comparable = True + is_extended_positive = True + is_prime = False + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _sympystr(self, printer) -> str: + return "int_oo" + + def _eval_subs(self, old, new): + if self == old: + return new + + # We could do these, not sure about it + """ + def _eval_evalf(self, prec=None): + return Float('inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in (S.Infinity, S.NegativeInfinity): + return other + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.NegativeInfinity + if other is S.NegativeInfinity: + return S.Infinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.NegativeIntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return S.Infinity # truediv produces float + return S.NegativeInfinity # truediv produces float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.NegativeIntInfinity + + def _eval_power(self, expt): + if expt.is_extended_positive: + return S.IntInfinity + if expt.is_extended_negative: + return S.Zero + if expt is S.NaN: + return S.NaN + if expt is S.ComplexInfinity: + return S.NaN + if expt.is_extended_real is False and expt.is_number: + from sympy.functions.elementary.complexes import re + + expt_real = re(expt) + if expt_real.is_positive: + return S.ComplexInfinity + if expt_real.is_negative: + return S.Zero + if expt_real.is_zero: + return S.NaN + + return self ** expt.evalf() + + def _as_mpf_val(self, prec): + return mlib.finf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.IntInfinity + + def __ne__(self, other): + return other is not S.IntInfinity + + def __gt__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __ge__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + def __lt__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __le__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + +int_oo = S.IntInfinity + + +# pyrefly: ignore [invalid-inheritance] +class NegativeIntInfinity(Number, metaclass=Singleton): + """Negative integer infinite quantity. + + NegativeInfinity is a singleton, and can be accessed + by ``S.NegativeInfinity``. + + See Also + ======== + + IntInfinity + """ + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + is_integer = True + is_extended_real = True + is_commutative = True + is_comparable = True + is_extended_negative = True + is_number = True + is_prime = False + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _eval_subs(self, old, new): + if self == old: + return new + + def _sympystr(self, printer) -> str: + return "-int_oo" + + """ + def _eval_evalf(self, prec=None): + return Float('-inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.Infinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.NegativeInfinity: + return S.Infinity + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.IntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return self + return S.Infinity # truediv returns float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.IntInfinity + + def _eval_power(self, expt): + if expt.is_number: + if expt in ( + S.NaN, + S.Infinity, + S.NegativeInfinity, + S.IntInfinity, + S.NegativeIntInfinity, + ): + return S.NaN + + if isinstance(expt, sympy.Integer) and expt.is_extended_positive: + if expt.is_odd: + return S.NegativeIntInfinity + else: + return S.IntInfinity + + inf_part = S.IntInfinity**expt + s_part = S.NegativeOne**expt + if inf_part == 0 and s_part.is_finite: + return inf_part + if ( + inf_part is S.ComplexInfinity + and s_part.is_finite + and not s_part.is_zero + ): + return S.ComplexInfinity + return s_part * inf_part + + def _as_mpf_val(self, prec): + return mlib.fninf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.NegativeIntInfinity + + def __ne__(self, other): + return other is not S.NegativeIntInfinity + + def __gt__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __ge__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + def __lt__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __le__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + def as_powers_dict(self): + return {S.NegativeOne: 1, S.IntInfinity: 1} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/printers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/printers.py new file mode 100644 index 0000000000000000000000000000000000000000..7006b7f7fdc65552a239d805dc30119e8ac2acf5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/printers.py @@ -0,0 +1,593 @@ +import sys + +import sympy +from sympy.printing.precedence import PRECEDENCE, precedence +from sympy.printing.str import StrPrinter + + +INDEX_TYPE = "int64_t" +INDEX_TYPE_MAX = (1 << 63) - 1 +INDEX_TYPE_MIN = -1 << 63 + + +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python +class ExprPrinter(StrPrinter): + # override this so that _print_FloorDiv is used + printmethod = "_torch_sympystr" + + def _print_Mul(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, "*", precedence(expr)) + + def _print_Not(self, expr: sympy.Expr) -> str: + return f"not ({self._print(expr.args[0])})" + + def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str: + return self.stringify(expr.args, " + ", precedence(expr)) + + def _print_Relational(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr)) + + def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"]) + + def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"]) + + def _print_BitwiseFn_bitwise_xor(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ^ ", PRECEDENCE["BitwiseXor"]) + + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent + def _print_Mod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str: + s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + return f"({s})" + + def _print_CleanDiv(self, expr: sympy.Expr) -> str: + return self._print_FloorDiv(expr) + + def _print_Identity(self, expr: sympy.Expr) -> str: + return self._print(expr.args[0]) + + def _print_Float(self, expr: sympy.Expr) -> str: + if expr._prec == 53: + # IEEE-754 double precision have 53 bits. SymPy prints them with + # 15 digits, but we need 17 for round-trip correctness + return str(sympy.Float(expr, dps=17)) + else: + # We don't use other precisions in pytorch + return str(expr) + + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + # pyrefly: ignore [bad-override] + def _print_Pow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + if exp != int(exp): + raise AssertionError(exp) + exp = int(exp) + if exp < 0: + raise AssertionError(f"exponent must be non-negative, got {exp}") + if exp > 0: + return self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + +class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("ToFloat expects exactly one argument") + # NB: We use sym_float here because the printer is used for cache + # serialization, and cache guards get evaluated with SymInt to + # propagate guards to the parent ShapeEnv. However, this comes at a + # runtime cost for guards involving float. If this is unacceptable + # overhead, what you want to do is have two separate printers for + # SymInt, one for when the inputs are guaranteed to be int, and + # another for when they could be SymInt. + # + # NB: sym_min/sym_max also have this problem, but I chose not to fix + # those. + # + # See https://github.com/pytorch/pytorch/issues/142507 for more + # context. + return f"torch.sym_float({self._print(expr.args[0])})" + + def _print_And(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " and ", precedence(expr)) + + def _print_Or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " or ", precedence(expr)) + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + if div != "1": + x = f"({x} // {div})" + return f"({x} % {mod})" + + def _print_Infinity(self, expr: sympy.Expr) -> str: + return "math.inf" + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args) + return f"{x} // {div}" + + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + return f"math.sqrt({self._print(expr)})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return self._helper_sqrt(expr.args[0]) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + def _print_floor(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("floor expects exactly one argument") + return f"math.floor({self._print(expr.args[0])})" + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("FloorToInt expects exactly one argument") + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("TruncToInt expects exactly one argument") + # This also could have been int(), they'll do the same thing for float + return f"math.trunc({self._print(expr.args[0])})" + + def _print_ceiling(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("ceiling expects exactly one argument") + return f"math.ceil({self._print(expr.args[0])})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("CeilToInt expects exactly one argument") + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("Abs expects exactly one argument") + return f"abs({self._print(expr.args[0])})" + + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion + def _print_Max(self, expr: sympy.Expr) -> str: + if len(expr.args) < 2: + raise AssertionError("Max expects at least two arguments") + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr: sympy.Expr) -> str: + if len(expr.args) < 2: + raise AssertionError("Min expects at least two arguments") + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("cos expects exactly one argument") + return f"math.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("cosh expects exactly one argument") + return f"math.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("acos expects exactly one argument") + return f"math.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("sin expects exactly one argument") + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("sinh expects exactly one argument") + return f"math.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("asin expects exactly one argument") + return f"math.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("tan expects exactly one argument") + return f"math.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("tanh expects exactly one argument") + return f"math.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("atan expects exactly one argument") + return f"math.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("log2 expects exactly one argument") + return f"math.log2({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("RoundToInt expects exactly one argument") + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise AssertionError("RoundDecimal expects exactly two arguments") + number, ndigits = expr.args + if not isinstance(ndigits, sympy.Integer): + raise TypeError("ndigits must be an instance of sympy.Integer") + return f"round({self._print(number)}, {ndigits})" + + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: str | None = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr: sympy.Expr) -> str: + suffix = "LL" if sys.platform in ["darwin", "win32"] else "L" + i = int(expr) + if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN: + raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}") + elif i == INDEX_TYPE_MIN: + if i != (-1) << 63: + raise AssertionError("unexpected minimum index type value") + # Writing -9223372036854775808L makes the value overflow + # as it is parsed as -(9223372036854775808L) by the C/C++ compiler + return f"(-1{suffix} << 63)" + return f"{i}{suffix}" + + def _print_Where(self, expr: sympy.Expr) -> str: + c, p, q = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + return f"{c} ? {p} : {q}" + + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) + result: str | None = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.doprint(mod) + return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("floor expects exactly one argument") + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("FloorToInt expects exactly one argument") + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("TruncToInt expects exactly one argument") + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("TruncToFloat expects exactly one argument") + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("ToFloat expects exactly one argument") + return f"static_cast({self._print(expr.args[0])})" + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + return f"c10::div_mod({x}, {div})" + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + # Implement the special-case of 2**x for now + base, exp = expr.args + if base == 2: + return f"(1 << ({self._print(exp)}))" + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + + def _print_Pow(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + base, exp = expr.args + + if exp == 0.5 or exp == -0.5: + base = self._print(base) + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + if exp.is_integer: + exp = int(exp) + if exp > 0: + r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + elif exp < -1: + r = ( + "1.0/(" + + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"]) + + ")" + ) + elif exp == -1: + r = "1.0/" + self._print(base) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + else: + # TODO: float vs double + return f"std::pow({base}, {float(exp)})" + + def _print_Rational(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("ceiling expects exactly one argument") + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("CeilToInt expects exactly one argument") + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min<{INDEX_TYPE}>({il})" + + def _print_Max(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max<{INDEX_TYPE}>({il})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("Abs expects exactly one argument") + return f"std::abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("cos expects exactly one argument") + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("cosh expects exactly one argument") + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("acos expects exactly one argument") + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("sin expects exactly one argument") + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("sinh expects exactly one argument") + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("asin expects exactly one argument") + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("tan expects exactly one argument") + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("tanh expects exactly one argument") + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("atan expects exactly one argument") + return f"std::atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return f"std::sqrt({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + return f"std::log2({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + if len(expr.args) != 1: + raise AssertionError("RoundToInt expects exactly one argument") + # TODO: dispatch to llrint depending on index type + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise AssertionError("RoundDecimal expects exactly two arguments") + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + if ndigits >= 0: + raise AssertionError("ndigits must be negative for integer inputs") + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr: sympy.Expr) -> str: + return "true" + + def _print_BooleanFalse(self, expr: sympy.Expr) -> str: + return "false" + + def _print_Infinity(self, expr: sympy.Expr) -> str: + return "std::numeric_limits::infinity()" + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + return f"-{self._print_Infinity(expr)}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/reference.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..015285eaaa1b6e8d59c57a5c943e7f2c73768a4f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/reference.py @@ -0,0 +1,600 @@ +# mypy: allow-untyped-defs +import math +import operator +from typing import NoReturn + +import sympy + +import torch +from torch.utils._sympy.functions import ( + _keep_float, + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, + BitwiseFn_bitwise_xor, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Max, + Min, + Mod, + OpaqueUnaryFn_exp, + OpaqueUnaryFn_log, + OpaqueUnaryFn_log2, + OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, +) + + +# The sympy interpretation of operators. It will also sometimes work with +# plain int/float, but if you do certain operations you will get out a +# sympy.Basic in the end. If you want the Python/FX traceable interpretation, +# check PythonReferenceAnalysis. +# NB: For magic methods this needs to use normal magic methods +# so that test_magic_methods works +class ReferenceAnalysis: + @staticmethod + def constant(c, dtype): + return sympy.sympify(c) + + @staticmethod + def or_(a, b): + return a | b + + @staticmethod + def and_(a, b): + return a & b + + @staticmethod + def eq(a, b): + if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr): + return sympy.Eq(a, b) + return a == b + + @classmethod + def ne(cls, a, b): + return cls.not_(cls.eq(a, b)) + + @staticmethod + def lt(a, b): + return a < b + + @staticmethod + def gt(a, b): + return a > b + + @staticmethod + def le(a, b): + return a <= b + + @staticmethod + def ge(a, b): + return a >= b + + @staticmethod + def not_(a): + if isinstance(a, bool): + raise AssertionError("not_ needs sympy expr") + return ~a + + @staticmethod + def reciprocal(x): + return FloatTrueDiv(1.0, x) + + @staticmethod + def square(x): + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + @staticmethod + def mod(x, y): + return Mod(x, y) + + @staticmethod + def abs(x): + return abs(x) + + @staticmethod + def neg(x): + return -x + + @staticmethod + def truediv(a, b): + return FloatTrueDiv(a, b) + + @staticmethod + def int_truediv(a, b): + return IntTrueDiv(a, b) + + @staticmethod + def floordiv(a, b): + return FloorDiv(a, b) + + @staticmethod + def truncdiv(a, b) -> NoReturn: + raise NotImplementedError("TODO: truncdiv") + + @staticmethod + def add(a, b): + return _keep_float(operator.add)(a, b) + + @classmethod + def sym_sum(cls, args): + return sympy.Add(*args) + + @staticmethod + def mul(a, b): + return _keep_float(operator.mul)(a, b) + + @staticmethod + def sub(a, b): + return _keep_float(operator.sub)(a, b) + + @staticmethod + def exp(x): + return OpaqueUnaryFn_exp(x) + + @staticmethod + def log(x): + return OpaqueUnaryFn_log(x) + + @staticmethod + def log2(x): + return OpaqueUnaryFn_log2(x) + + @staticmethod + def sqrt(x): + return OpaqueUnaryFn_sqrt(x) + + @staticmethod + def pow(a, b): + # pyrefly: ignore [bad-argument-type] + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) + + @staticmethod + def minimum(a, b): + return Min(a, b) + + @staticmethod + def maximum(a, b): + return Max(a, b) + + @staticmethod + def round_to_int(a, dtype): + return RoundToInt(a) + + @staticmethod + def round_decimal(a, b): + return RoundDecimal(a, b) + + @staticmethod + def bitwise_and(a, b): + return BitwiseFn_bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return BitwiseFn_bitwise_or(a, b) + + @staticmethod + def bitwise_xor(a, b): + return BitwiseFn_bitwise_xor(a, b) + + +# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain +# Python types and is FX traceable. Inheritance here is purely for code +# sharing (TODO: considering splitting out a BaseReferenceAnalysis). +class PythonReferenceAnalysis(ReferenceAnalysis): + @staticmethod + def constant(c, dtype): + if dtype is torch.int64: + return int(c) + elif dtype is torch.double: + return float(c) + elif dtype is torch.bool: + return bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + + @staticmethod + def not_(a): + return torch.sym_not(a) + + @classmethod + def sym_sum(cls, args): + if len(args) == 0: + return 0 + if len(args) == 1: + return args[0] + acc = cls.add(args[0], args[1]) + for i in range(2, len(args)): + acc = cls.add(acc, args[i]) + return acc + + @staticmethod + def floordiv(a, b): + return a // b + + @staticmethod + def mod(x, y): + return x % y + + @staticmethod + def python_mod(x, y): + return x % y + + @staticmethod + def truncdiv(a, b): + return a / b + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return torch.sym_float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + @staticmethod + def exp(x) -> NoReturn: + raise AssertionError("exp is not valid shape sympy expr") + + @staticmethod + def log(x) -> NoReturn: + raise AssertionError("log is not valid shape sympy expr") + + @staticmethod + def log2(x): + return torch._sym_log2(x) # type: ignore[attr-defined] + + @staticmethod + def sqrt(x): + return torch._sym_sqrt(x) # type: ignore[attr-defined] + + @staticmethod + def minimum(a, b): + return torch.sym_min(a, b) + + @staticmethod + def maximum(a, b): + return torch.sym_max(a, b) + + @staticmethod + def floor_to_int(x, dtype): + return math.floor(x) + + @staticmethod + def ceil_to_int(x, dtype): + return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) + + @staticmethod + def bitwise_and(a, b): + return a & b + + @staticmethod + def bitwise_or(a, b): + return a | b + + @staticmethod + def bitwise_xor(a, b): + return a ^ b + + +# Like PythonReferenceAnalysis, but some export-unfriendly choices of +# operators to make things faster +class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis): + @staticmethod + def sym_sum(args): + return torch.sym_sum(args) + + +def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.prims.convert_element_type.default(x, dtype) + + +# Suppose we have some int/float arguments. This diagram commutes: +# +# int/float -- PythonReferenceAnalysis.op --> int/float +# | | +# | | +# torch.tensor(..., dtype=torch.int64/torch.float64) +# | | +# V V +# Tensor -- TensorReferenceAnalysis.op --> Tensor +# +# NB: int before and after must be representable in int64 (we will +# insert guards accordingly.) +# +# This is guaranteed to be FX traceable with OpOverloads only. +class TensorReferenceAnalysis: + # NB: This is actually dead, because with Proxy tracing the factory + # function isn't traced correctly. Here for completeness. + @staticmethod + def constant(c, dtype): + d: int | float | bool + if dtype is torch.int64: + d = int(c) + elif dtype is torch.double: + d = float(c) + elif dtype is torch.bool: + d = bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + return torch.ops.aten.scalar_tensor.default(d, dtype=dtype) + + @staticmethod + def or_(a, b): + return torch.ops.aten.logical_or.default(a, b) + + @staticmethod + def and_(a, b): + return torch.ops.aten.logical_and.default(a, b) + + @staticmethod + def bitwise_and(a, b): + return torch.ops.aten.bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return torch.ops.aten.bitwise_or(a, b) + + @staticmethod + def bitwise_xor(a, b): + return torch.ops.aten.bitwise_xor(a, b) + + @staticmethod + def eq(a, b): + return torch.ops.aten.eq.Tensor(a, b) + + @classmethod + def ne(cls, a, b): + return torch.ops.aten.ne.Tensor(a, b) + + @staticmethod + def lt(a, b): + return torch.ops.aten.lt.Tensor(a, b) + + @staticmethod + def gt(a, b): + return torch.ops.aten.gt.Tensor(a, b) + + @staticmethod + def le(a, b): + return torch.ops.aten.le.Tensor(a, b) + + @staticmethod + def ge(a, b): + return torch.ops.aten.ge.Tensor(a, b) + + @staticmethod + def not_(a): + return torch.ops.aten.logical_not.default(a) + + @staticmethod + def reciprocal(x): + return torch.ops.aten.reciprocal.default(x) + + @staticmethod + def square(x): + # TODO: maybe composite implicit autograd doesn't work here? + return torch.ops.aten.square.default(x) + + @staticmethod + def trunc_to_int(x, dtype): + return _to_dtype(torch.ops.aten.trunc.default(x), dtype) + + @staticmethod + def ceil_to_int(x, dtype): + return _to_dtype(torch.ops.aten.ceil.default(x), dtype) + + @staticmethod + def floor_to_int(x, dtype): + return _to_dtype(torch.ops.aten.floor.default(x), dtype) + + @staticmethod + def floor(x): + return torch.ops.aten.floor.default(x) + + @staticmethod + def ceil(x): + return torch.ops.aten.ceil.default(x) + + @staticmethod + def to_dtype(x, dtype): + return _to_dtype(x, dtype) + + @staticmethod + def mod(x, y) -> NoReturn: + # TODO: https://github.com/pytorch/pytorch/pull/133654 + raise NotImplementedError( + "no C-style modulus operation available from frontend atm" + ) + + @staticmethod + def abs(x): + return torch.ops.aten.abs.default(x) + + @staticmethod + def neg(x): + return torch.ops.aten.neg.default(x) + + @staticmethod + def truediv(a, b): + return torch.ops.aten.true_divide.Tensor(a, b) + + @staticmethod + def int_truediv(a, b): + raise NotImplementedError( + "Python int truediv difficult to implement in PyTorch atm" + ) + + # TODO: This is wrong, CPython has a custom implementation of true + # division that results in higher precision when the floats are + # sufficiently large. Short term fix: add a guard here + return torch.ops.aten.true_divide.default( + _to_dtype(a, torch.float64), _to_dtype(b, torch.float64) + ) + + @staticmethod + def floordiv(a, b): + return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor") + + @staticmethod + def truncdiv(a, b) -> NoReturn: + raise NotImplementedError( + "no C-style truncdiv operation available from frontend atm" + ) + + @staticmethod + def add(a, b): + return torch.ops.aten.add.Tensor(a, b) + + @staticmethod + def mul(a, b): + return torch.ops.aten.mul.Tensor(a, b) + + @staticmethod + def sub(a, b): + return torch.ops.aten.sub.Tensor(a, b) + + @staticmethod + def exp(x): + return torch.ops.aten.exp.default(x) + + @staticmethod + def log(x): + return torch.ops.aten.log.default(x) + + @staticmethod + def log2(x): + return torch.ops.aten.log2.default(x) + + @staticmethod + def sqrt(x): + return torch.ops.aten.sqrt.default(x) + + @staticmethod + def sin(x): + return torch.ops.aten.sin.default(x) + + @staticmethod + def cos(x): + return torch.ops.aten.cos.default(x) + + @staticmethod + def tanh(x): + return torch.ops.aten.tanh.default(x) + + @staticmethod + def sinh(x): + return torch.ops.aten.sinh.default(x) + + @staticmethod + def cosh(x): + return torch.ops.aten.cosh.default(x) + + @staticmethod + def tan(x): + return torch.ops.aten.tan.default(x) + + @staticmethod + def acos(x): + return torch.ops.aten.acos.default(x) + + @staticmethod + def atan(x): + return torch.ops.aten.atan.default(x) + + @staticmethod + def asin(x): + return torch.ops.aten.asin.default(x) + + @staticmethod + def pow(a, b): + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def pow_by_natural(a, b): + # NB: pow handles int x int fine + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def minimum(a, b): + return torch.ops.aten.minimum.default(a, b) + + @staticmethod + def maximum(a, b): + return torch.ops.aten.maximum.default(a, b) + + @staticmethod + def round_to_int(a, dtype): + return torch.ops.aten.round.default(a) + + @staticmethod + def round_decimal(a, b) -> NoReturn: + raise NotImplementedError( + "round decimal doesn't support Tensor second argument atm" + ) + + # return torch.ops.aten.round.decimals(a, b) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py new file mode 100644 index 0000000000000000000000000000000000000000..57d5615e552711a490e306ebc97a260a55251c21 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import sympy +from sympy.multipledispatch import dispatch + + +__all__ = ["SingletonInt"] + + +class SingletonInt(sympy.AtomicExpr): + # This is probably not super important unless we are in multiple dispatch + # situations with other more exotic Expr types. + _op_priority = 99999 + + def __new__(cls, *args, coeff=None, **kwargs): + instance = super().__new__(cls, *args, **kwargs) + return instance + + # The semantics of this class should match that of NestedIntSymNodeImpl in + # c10/core/NestedIntSymNodeImpl.h + def __init__(self, val, *, coeff=1) -> None: + self._val = val + self._coeff = coeff + super().__init__() + + # See NOTE [ Inequalities with nested int ] + def _eval_Eq(self, other): + if ( + isinstance(other, SingletonInt) + and other._val == self._val + and self._coeff == other._coeff + ): + return sympy.true + else: + return sympy.false + + # This is necessary so that calling expr.free_symbols on exprs that contain + # this Singleton does not error + @property + def free_symbols(self): + return set() + + def __mul__(self, other): + if isinstance(other, SingletonInt): + raise ValueError( + "SingletonInt cannot be multiplied by another SingletonInt" + ) + return SingletonInt(self._val, coeff=self._coeff * other) + + def __rmul__(self, other): + if isinstance(other, SingletonInt): + raise ValueError( + "SingletonInt cannot be multiplied by another SingletonInt" + ) + return SingletonInt(self._val, coeff=self._coeff * other) + + # Make sure we promptly raise an error instead of falling back to building + # an expression tree. There are probably more ops, how can we be exhaustive? + def __add__(self, other): + raise NotImplementedError("NYI") + + def __sub__(self, other): + raise NotImplementedError("NYI") + + def __truediv__(self, other): + raise NotImplementedError("NYI") + + def __floordiv__(self, other): + raise NotImplementedError("NYI") + + def __mod__(self, other): + raise NotImplementedError("NYI") + + +# See NOTE [ Inequalities with nested int ] +@dispatch(sympy.Integer, SingletonInt) +def _eval_is_ge(a, b): + if a < 2: + return sympy.false + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") + + +@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef] +def _eval_is_ge(a, b): # noqa: F811 + if b <= 2: + return sympy.true + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") + + +@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef] +def _eval_is_ge(a, b): # noqa: F811 + if a._val == b._val: + if a._coeff >= b._coeff: + return sympy.true + else: + return sympy.false + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/solve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/solve.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd5e1484601ffa1c7c2743ffa228c536cd54fb5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/solve.py @@ -0,0 +1,179 @@ +import logging + +import sympy + +from torch.utils._sympy.functions import FloorDiv + + +log = logging.getLogger(__name__) + +_MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = { + sympy.Eq: sympy.Eq, + sympy.Ne: sympy.Ne, + sympy.Ge: sympy.Le, + sympy.Gt: sympy.Lt, + sympy.Le: sympy.Ge, + sympy.Lt: sympy.Gt, +} + +INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le) + + +def mirror_rel_op(type: type) -> type[sympy.Rel] | None: + return _MIRROR_REL_OP.get(type) + + +# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side. +# +# Returns a tuple of: +# 1. The simplified expression +# 2. The expression on the right-hand side +# +# Returns 'None' if it can't reach a state where the only thing in the left +# hand side is 'thing'. +# +# 'trials': number of times 'try_solve' will try to isolate 'thing' to the +# left-hand side. +# +# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into +# inequalities. +def try_solve( + expr: sympy.Basic, + thing: sympy.Basic, + trials: int = 5, + floordiv_inequality: bool = True, +) -> tuple[sympy.Rel, sympy.Expr] | None: + mirror = mirror_rel_op(type(expr)) + + # Ignore unsupported expressions: + # - Those that are not relational operations + # - Those that don't have a mirror (just avoiding unexpected classes) + if not isinstance(expr, sympy.Rel) or mirror is None: + log.debug("expression with unsupported type: %s", type(expr)) + return None + + lhs_has_thing = expr.lhs.has(thing) + rhs_has_thing = expr.rhs.has(thing) + + # Give up when 'thing' appears on both sides of the relational expression. + # That is because, as is, we assume the thing we are trying to isolate is + # only on the right-hand side. + if lhs_has_thing and rhs_has_thing: + log.debug("thing (%s) found in both sides of expression: %s", thing, expr) + return None + + # Try considering both LHS and RHS by mirroring the original expression: + # a < b ==> b > a + expressions = [] + + # Add each version of 'expr' if 'thing' is in its left-hand side. + if lhs_has_thing: + expressions.append(expr) + if rhs_has_thing: + expressions.append(mirror(expr.rhs, expr.lhs)) + + for e in expressions: + if e is None: + continue + + if not isinstance(e, sympy.Rel): + raise AssertionError("expected sympy.Rel") + + for _ in range(trials): + trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality) + # Stop if there was no change in this trial. + if trial == e: + break + e = trial # type: ignore[assignment] + + # Return if we were able to isolate 'thing' on the left-hand side. + if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) + return e, e.rhs + + return None + + +def _try_isolate_lhs( + e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool +) -> sympy.Basic: + op = type(e) + + if isinstance(e, sympy.Rel): + # Move any constants in the left-hand side to the right-hand side. + lhs_not_thing = ( + sum(a for a in e.lhs.args if not a.has(thing)) + if isinstance(e.lhs, sympy.Add) + else 0 + ) + e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined] + + # Divide both sides by the factors that don't contain thing. + if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul): + lhs, rhs = e.args + other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)]) + + # If we can't tell whether 'other' is negative or positive, we do nothing. + # That is because we don't know whether we have mirror the operation or not. + # We also divide only when we know 'rhs' is not zero. + if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not ( + not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero + ): + # Divide both sides by 'other'. + lhs = lhs / other + rhs = rhs / other + + # If 'e' is an inequality and 'other' is negative, we have to + # mirror the expression. + if isinstance(e, INEQUALITY_TYPES) and other.is_negative: + op = mirror_rel_op(op) # type: ignore[assignment] + + if op is None: + raise AssertionError("expected op to be not None") + e = op(lhs, rhs) + + ################################################################################ + # left-hand side is FloorDiv + ################################################################################ + # + # Given the expression: a // b op c + # where 'op' is a relational operation, these rules only work if: + # - b > 0 + # - c is an integer + if ( + floordiv_inequality + and isinstance(e, sympy.Rel) + and isinstance(e.lhs, FloorDiv) + and e.lhs.divisor.is_positive + and e.rhs.is_integer + ): + # a // b == expr + # => a >= (b * expr) and a < (b * (expr + 1)) + if isinstance(e, sympy.Eq): + numerator, denominator = e.lhs.args + return sympy.And( + sympy.Ge(numerator, (e.rhs * denominator)), + sympy.Lt(numerator, ((e.rhs + 1) * denominator)), + ) + # a // b != expr + # => a < (b * expr) or a >= (b * (expr + 1)) + if isinstance(e, sympy.Ne): + numerator, denominator = e.lhs.args + return sympy.Or( + sympy.Lt(numerator, (e.rhs * denominator)), + sympy.Ge(numerator, ((e.rhs + 1) * denominator)), + ) + # The transformations below only work if b is positive. + # Note: we only have this information for constants. + # a // b > expr => a >= b * (expr + 1) + # a // b >= expr => a >= b * expr + if isinstance(e, (sympy.Gt, sympy.Ge)): + quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) + return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) + # a // b < expr => a < b * expr + # a // b <= expr => a < b * (expr + 1) + if isinstance(e, (sympy.Lt, sympy.Le)): + quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) + return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) + + return e diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py new file mode 100644 index 0000000000000000000000000000000000000000..61a7c147458e03e2eaf1704a42335e357eb69be7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +""" +This file contains canonical definitions for our symbol naming conventions, +across torch.fx.experimental.symbolic_shapes and torch._inductor. The +intention is: + +1. To make it easily greppable where all the sites we use a prefix are +2. Make it possible to easily tell if we can introduce a new prefix without + introducing a conflict + +You can occasionally test if prefixes have been hardcoded by renaming prefixes +in this file and seeing what breaks. +""" + +from collections.abc import Iterable +from enum import auto, Enum + +import sympy + + +class SymT(Enum): + SIZE = auto() + FLOAT = auto() + UNBACKED_INT = auto() + UNBACKED_FLOAT = auto() + # Inductor: The intermediates in inner_fn tmp0, one generated per ops call. + # If one of these shows up in an indexing expression, that means an + # indirect load is happening. + TMP = auto() + # Inductor: Placeholder variable that is later replaced with TMP + INDIRECT = auto() + # Inductor: Some size expressions are replaced with a precomputed size ps0 + # which is computed host side, and then directly reused in the kernel, so + # we don't repeatedly recompute it on device. + PRECOMPUTED_SIZE = auto() + # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced + # dim in the loop + INDEX = auto() + # Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over + # reduced dim(s) in the loop + R0_INDEX = auto() + R1_INDEX = auto() + # Inductor: In templated kernels torch._inductor.kernel, we have a hook to + # store the final output and append epilogue fusions. To do this, we must + # know what the indexes the outputs range over. NB: These will also + # advertise as INDEX, this is... probably OK? + TEMPLATE_INDEX = auto() + # Inductor: iteration domain for blockIdx.x/blockIdx.y + XBLOCK = auto() + YBLOCK = auto() + ZBLOCK = auto() + # Inductor: this is used solely for dynamic_reshape_indexer + VIEW = auto() + # Alternate (non-modular) indexing used in halide kernels + HALIDE = auto() + + +# Invariant: there must not be a prefix which is a prefix of another string, +# as this introduces ambiguity +prefix_str = { + SymT.SIZE: "s", # integer + SymT.UNBACKED_INT: "u", # integer + # Prefix z here is chosen to avoid false aliasing in symbol_is_type test + # DO NOT add a "z" type. You also need to avoid conflicts on these + # prefixes but this is somewhat easier to manage + SymT.FLOAT: "zf", + SymT.UNBACKED_FLOAT: "zuf", + SymT.TMP: "tmp", + SymT.PRECOMPUTED_SIZE: "ps", + SymT.INDEX: "i", + SymT.R0_INDEX: "r0_", + SymT.R1_INDEX: "r1_", + SymT.TEMPLATE_INDEX: "idx", + SymT.XBLOCK: "x", + SymT.YBLOCK: "y", + SymT.ZBLOCK: "z", + SymT.INDIRECT: "indirect", # false aliasing? + SymT.VIEW: "view", + SymT.HALIDE: "h", +} + + +def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol: + # TODO: maybe put the assumptions here directly + return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs) + + +# This type is a little wider than it should be, because free_symbols says +# that it contains Basic, rather than Symbol +def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool: + if not isinstance(sym, sympy.Symbol): + raise AssertionError("expected sympy.Symbol") + name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK + if isinstance(prefix, SymT): + return name_str.startswith(prefix_str[prefix]) + else: + return name_str.startswith(tuple(prefix_str[p] for p in prefix)) + + +def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool: + return any(symbol_is_type(v, prefix) for v in e.free_symbols) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfd086a10aa7f023cef849b3da875a722d20505 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py @@ -0,0 +1,1145 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import math +import operator +from collections.abc import Callable +from typing import Generic, overload, SupportsFloat, TYPE_CHECKING, TypeGuard, TypeVar +from typing_extensions import TypeIs + +import sympy +from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom + +import torch +from torch._logging import LazyString +from torch._prims_common import dtype_to_type + +from .functions import ( + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + OpaqueUnaryFn_exp, + OpaqueUnaryFn_log, + OpaqueUnaryFn_log2, + OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, +) +from .interp import sympy_interp +from .numbers import int_oo, IntInfinity, NegativeIntInfinity + + +log = logging.getLogger(__name__) + +__all__ = ["ValueRanges", "bound_sympy"] + +_T = TypeVar("_T", sympy.Expr, SympyBoolean) + + +class ValueRangeError(RuntimeError): + pass + + +# Like sympify, but supports less stuff, and also ensures that direct +# sympy expressions don't have free variables +def simple_sympify(e): + if isinstance(e, bool): + return sympy.true if e else sympy.false + elif isinstance(e, int): + return sympy.Integer(e) + elif isinstance(e, float): + # infinity is special; we use it to bracket integers as well + if math.isinf(e): + return sympy.oo if e > 0 else -sympy.oo + return sympy.Float(e) + elif isinstance(e, sympy.Expr): + if not getattr(e, "is_number", False): + raise AssertionError(e) + # NaNs can occur when doing things like 0 * sympy.oo, but it is better + # if the operator notices this and takes care of it, because sometimes + # the NaN is inappropriate (for example, for ints, the [-oo, oo] range + # should go to zero when multiplied with [0, 0]) + if e == sympy.nan: + raise AssertionError("sympy expression is NaN") + return e + elif isinstance(e, BooleanAtom): + return e + else: + raise AssertionError(f"not simple sympy type {type(e)}: {e}") + + +# Sympy atomics only. Unlike <=, it also works on Sympy bools. +def sympy_generic_le(lower, upper): + if isinstance(lower, sympy.Expr): + if not isinstance(upper, sympy.Expr): + raise AssertionError( + "upper must be a sympy.Expr when lower is a sympy.Expr" + ) + # instead of lower <= upper, we do upper >= lower since upper is mostly int_oo + # and we have better code paths there. + return upper >= lower + else: + # only negative condition is True > False + if not isinstance(lower, SympyBoolean) or not isinstance(upper, SympyBoolean): + raise AssertionError((lower, upper)) + return not (lower and not upper) + + +def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]: + return vr.is_bool + + +def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]: + return not vr.is_bool + + +def is_sympy_integer(value) -> TypeIs[sympy.Integer]: + return isinstance(value, sympy.Integer) + + +ExprIn = int | float | sympy.Expr +BoolIn = bool | SympyBoolean +AllIn = ExprIn | BoolIn +ExprFn = Callable[[sympy.Expr], sympy.Expr] +ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr] +BoolFn = Callable[[SympyBoolean], SympyBoolean] +BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean] +AllFn = ExprFn | BoolFn +AllFn2 = ExprFn2 | BoolFn2 + + +@dataclasses.dataclass(frozen=True) +class ValueRanges(Generic[_T]): + if TYPE_CHECKING: + # ruff doesn't understand circular references but mypy does + # pyrefly: ignore [unbound-name] + ExprVR = ValueRanges[sympy.Expr] # noqa: F821 + # pyrefly: ignore [unbound-name] + BoolVR = ValueRanges[SympyBoolean] # noqa: F821 + AllVR = ExprVR | BoolVR + + # Although the type signature here suggests you can pass any + # sympy expression, in practice the analysis here only works + # with constant sympy expressions + lower: _T + upper: _T + is_bool: bool + is_int: bool + is_float: bool + + def __repr__(self) -> str: + return f"VR[{self.lower}, {self.upper}]" + + @overload + def __init__( + self: ValueRanges[sympy.Expr], + lower: ExprIn, + upper: ExprIn, + ) -> None: ... + + @overload + def __init__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + lower: BoolIn, + upper: BoolIn, + ) -> None: ... + + def __init__(self, lower: AllIn, upper: AllIn) -> None: + lower = simple_sympify(lower) + upper = simple_sympify(upper) + # TODO: when the bounds have free variables, this may be + # nontrivial to actually verify + try: + if not sympy_generic_le(lower, upper): + raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") + except TypeError as e: + raise TypeError(f"Could not compare {lower} <= {upper}") from e + + is_bool_lower = isinstance(lower, SympyBoolean) + is_bool_upper = isinstance(upper, SympyBoolean) + if is_bool_lower != is_bool_upper: + raise AssertionError((lower, upper)) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + if isinstance(lower, sympy.Integer) and upper == sympy.oo: + upper = int_oo + if isinstance(upper, sympy.Integer) and lower == -sympy.oo: + lower = -int_oo + # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed + integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity) + is_int_lower = isinstance(lower, integer_types) + is_int_upper = isinstance(upper, integer_types) + + # Because this is a frozen class + object.__setattr__(self, "lower", lower) + object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints + # + # NB: is_bool_lower == is_bool_upper, so we only need to check one + object.__setattr__(self, "is_bool", is_bool_lower) + object.__setattr__( + self, + "is_int", + not self.is_bool and is_int_lower and is_int_upper, + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + if not self.is_bool and not self.is_int and not self.is_float: + raise AssertionError((lower, upper)) + + def boolify(self) -> ValueRanges[SympyBoolean]: + if vr_is_bool(self): + return self + elif self == ValueRanges.unknown(): + return ValueRanges.unknown_bool() + else: + raise AssertionError(f"not bool like {self}") + + def __contains__(self, x: AllIn) -> bool: + return ValueRanges.wrap(x).issubset(self) + + def issubset(self, other): + if other is self.unknown_int(): + return True + return sympy_generic_le(other.lower, self.lower) and sympy_generic_le( + self.upper, other.upper + ) + + def tighten(self, other) -> ValueRanges: + """Given two ValueRanges, returns their intersection""" + return self & other + + # Intersection + @overload + def __and__( + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], + ) -> ValueRanges[sympy.Expr]: ... + + @overload + def __and__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], + ) -> ValueRanges[SympyBoolean]: ... + + def __and__(self: AllVR, other: AllVR) -> AllVR: + if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): + return self + if self in (ValueRanges.unknown(), ValueRanges.unknown_int()): + return other + if self.is_bool != other.is_bool: + raise AssertionError((self, other)) + if self.is_int != other.is_int: + raise AssertionError((self, other)) + if self.is_float != other.is_float: + raise AssertionError((self, other)) + if self.is_bool: + return ValueRanges( + sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) + ) + else: + return ValueRanges( + sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper) + ) + + # Union + @overload + def __or__( + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], + ) -> ValueRanges[sympy.Expr]: ... + + @overload + def __or__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], + ) -> ValueRanges[SympyBoolean]: ... + + def __or__(self: AllVR, other: AllVR) -> AllVR: + if ValueRanges.unknown() in (self, other): + return ValueRanges.unknown() + if self.is_bool != other.is_bool: + raise AssertionError((self, other)) + if self.is_int != other.is_int: + raise AssertionError((self, other)) + if self.is_float != other.is_float: + raise AssertionError((self, other)) + if self.is_bool: + return ValueRanges( + sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper) + ) + else: + return ValueRanges( + sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper) + ) + + def is_singleton(self) -> bool: + return self.lower == self.upper + + @staticmethod + @functools.cache + def unknown() -> ValueRanges[sympy.Expr]: + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + @functools.cache + def unknown_int() -> ValueRanges[sympy.Expr]: + return ValueRanges(-int_oo, int_oo) + + @staticmethod + @functools.cache + def unknown_bool() -> ValueRanges[SympyBoolean]: + return ValueRanges(sympy.false, sympy.true) + + @overload + @staticmethod + # work around the fact that bool and int overlap + def wrap(arg: ExprIn | ExprVR) -> ExprVR: # type: ignore[overload-overlap] + ... + + @overload + @staticmethod + def wrap(arg: BoolIn | BoolVR) -> BoolVR: # type: ignore[misc] + ... + + @staticmethod + def wrap(arg: AllIn | AllVR) -> AllVR: + if isinstance(arg, ValueRanges): + return arg + if isinstance(arg, float) and math.isnan(arg): + return ValueRanges.unknown() + # arg is either ExprIn or BoolIn, but we don't know it here + return ValueRanges(arg, arg) # type: ignore[arg-type] + + @staticmethod + def increasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: + """Increasing: x <= y => f(x) <= f(y).""" + x = ValueRanges.wrap(x) + return ValueRanges(fn(x.lower), fn(x.upper)) + + @overload + @staticmethod + def decreasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: ... + + @overload + @staticmethod + def decreasing_map(x: BoolIn | BoolVR, fn: BoolFn) -> BoolVR: # type: ignore[misc] + ... + + @staticmethod + def decreasing_map(x: AllIn | AllVR, fn: AllFn) -> AllVR: + """Decreasing: x <= y => f(x) >= f(y).""" + x = ValueRanges.wrap(x) + # consistently either Expr or Bool, but we don't know it here + return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type] + + @staticmethod + def monotone_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: + """It's increasing or decreasing.""" + x = ValueRanges.wrap(x) + l = fn(x.lower) + u = fn(x.upper) + return ValueRanges(min(l, u), max(l, u)) + + @staticmethod + def convex_min_zero_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: + """Fn is convex and has a minimum at 0.""" + x = ValueRanges.wrap(x) + if 0 in x: + upper = max(fn(x.lower), fn(x.upper)) + upper = simple_sympify(upper) + if isinstance(upper, sympy.Float) or upper == sympy.oo: + return ValueRanges(0.0, upper) + return ValueRanges(0, upper) + return ValueRanges.monotone_map(x, fn) + + @overload + @staticmethod + def coordinatewise_increasing_map( + x: ExprIn | ExprVR, + y: ExprIn | ExprVR, + fn: ExprFn2, + ) -> ExprVR: ... + + @overload + @staticmethod + def coordinatewise_increasing_map( # type: ignore[misc] + x: BoolIn | BoolVR, + y: BoolIn | BoolVR, + fn: BoolFn2, + ) -> BoolVR: ... + + @staticmethod + def coordinatewise_increasing_map( + x: AllIn | AllVR, + y: AllIn | AllVR, + fn: AllFn2, + ) -> AllVR: + """ + It's increasing on each coordinate. + + Mathematically: + For every 1 <= i <= n and x_i <= y_i we have that + f(x1, .., xn) <= f(x1, , yi, ..., xn) + """ + x, y = ValueRanges.wrap(x), ValueRanges.wrap(y) + return ValueRanges( + fn(x.lower, y.lower), # type: ignore[arg-type] + fn(x.upper, y.upper), # type: ignore[arg-type] + ) + + @classmethod + def coordinatewise_monotone_map(cls, x, y, fn): + """It's increasing or decreasing on each coordinate.""" + x, y = cls.wrap(x), cls.wrap(y) + products = [ + fn(a, b) + for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper]) + ] + return ValueRanges(min(products), max(products)) + + +class SymPyValueRangeAnalysis: + """ + It gives bounds on a SymPy operator given bounds on its arguments + See the function `bound_sympy` for a function that applies this logic to a full SymPy expression + """ + + @staticmethod + def constant(value, dtype): + if isinstance(value, ValueRanges): + if not value.is_singleton(): + raise AssertionError("ValueRanges must be a singleton for constant()") + value = value.lower + # NB: value is NOT a sympy expression, it's a constant! + is_python = isinstance(value, (int, float, bool)) + if not is_python and not isinstance( + value, (BooleanAtom, sympy.Integer, sympy.Number) + ): + raise AssertionError(f"not a supported constant type: {type(value)}") + + # using nan makes subsequent computation throw, and for the purposes of optimization + # returning -math.inf - math.inf is equivalent to giving up + if isinstance(value, SupportsFloat) and math.isnan(value): + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges.unknown_int() + + if is_python: + type_ = dtype_to_type(dtype) + value = type_(value) + else: + # We do a type check on a best-effort basis + # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision + if dtype == torch.bool: + if not isinstance(value, BooleanAtom): + raise AssertionError("expected BooleanAtom for bool dtype") + elif dtype.is_floating_point: + if value.is_finite and not value.is_real: + raise AssertionError( + "expected float-like sympy value for float dtype" + ) + else: + # dtype is intXX + if not getattr(value, "is_integer", False): + raise AssertionError("expected integer sympy value for int dtype") + + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + # pyrefly: ignore [bad-argument-type] + return ValueRanges.increasing_map(a, ToFloat) + elif dtype == torch.bool: + return ValueRanges.unknown_bool() + elif not dtype.is_floating_point: + return ValueRanges.unknown_int() + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + # pyrefly: ignore [bad-argument-type] + return ValueRanges.increasing_map(a, TruncToInt) + + @staticmethod + def not_(a): + a = ValueRanges.wrap(a) + a = a.boolify() + if not a.is_bool: + raise AssertionError("not_ expects a boolean ValueRanges") + return ValueRanges.decreasing_map(a, sympy.Not) + + @staticmethod + def or_(a, b): + return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or) + + @staticmethod + def and_(a, b): + return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And) + + @staticmethod + def _bool_to_int(x): + if x.is_singleton(): + return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0)) + else: + return ValueRanges(sympy.Integer(0), sympy.Integer(1)) + + @classmethod + def bitwise_and(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.and_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + lower = min(a.lower, b.lower) + if lower < 0 and lower != -sympy.oo and lower != -int_oo: + # If both lower bounds are negative, then bits start like + # 1...10..., so the smallest possible value is 1...101...1. + # Thus, we need to find the next smallest power of 2 (inclusive). + try: + lower = -(1 << int(-lower - 1).bit_length()) + except Exception: + lower = -int_oo + else: + lower = 0 + return ValueRanges(lower, max(a.upper, b.upper)) + + @classmethod + def bitwise_or(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.or_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + upper = max(a.upper, b.upper) + if upper == 0: + upper = 0 + elif upper > 0 and upper != sympy.oo and upper != int_oo: + # If both upper bounds are positive, then the largest + # possible value is 01...1, so we need to find + # next largest power of 2 (exclusive), minus 1 + try: + upper = (1 << int(upper).bit_length()) - 1 + except Exception: + upper = int_oo + elif upper < 0: + upper = -1 + return ValueRanges(min(a.lower, b.lower), upper) + + @classmethod + def bitwise_xor(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + bounds = { + a.lower ^ b.lower, + a.lower ^ b.upper, + a.upper ^ b.lower, + a.upper ^ b.upper, + } + + has_false = any(bound == sympy.false for bound in bounds) + has_true = any(bound == sympy.true for bound in bounds) + + if has_false and has_true: + lower, upper = sympy.false, sympy.true + elif has_true: + lower = upper = sympy.true + elif has_false: + lower = upper = sympy.false + else: + raise AssertionError(f"Non-boolean xor result: {bounds}") + + return ValueRanges(lower, upper) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + if ( + a.lower == a.upper + and b.lower == b.upper + and is_sympy_integer(a.lower) + and is_sympy_integer(b.lower) + ): + value_range = a.lower ^ b.lower + return ValueRanges(value_range, value_range) + return ValueRanges(-int_oo, int_oo) + + @staticmethod + def eq(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton() and a.lower == b.lower: + return ValueRanges.wrap(sympy.true) + elif a.lower > b.upper or b.lower > a.upper: # ranges disjoint + return ValueRanges.wrap(sympy.false) + return ValueRanges(sympy.false, sympy.true) + + @classmethod + def ne(cls, a, b): + return cls.not_(cls.eq(a, b)) + + @classmethod + def identity(cls, a): + return ValueRanges.wrap(a) + + @classmethod + def lt(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_bool != b.is_bool: + raise AssertionError( + "operands must both be boolean ValueRanges or both non-boolean" + ) + if a.is_bool: + return cls.and_(cls.not_(a), b) + else: + if a.upper < b.lower: + return ValueRanges.wrap(sympy.true) + elif a.lower >= b.upper: + return ValueRanges.wrap(sympy.false) + return ValueRanges(sympy.false, sympy.true) + + @classmethod + def gt(cls, a, b): + return cls.lt(b, a) + + @classmethod + def le(cls, a, b): + return cls.not_(cls.gt(a, b)) + + @classmethod + def ge(cls, a, b): + return cls.not_(cls.lt(a, b)) + + @staticmethod + def add(a, b): + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) + + @classmethod + def mul(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + + if a.is_bool != b.is_bool: + raise AssertionError( + "operands must both be boolean ValueRanges or both non-boolean" + ) + if a.is_bool: + return cls.and_(a, b) + + def safe_mul(a, b): + # Make unknown() * wrap(0.0) == wrap(0.0) + if a == 0.0 or a == 0: + return a + elif b == 0.0 or b == 0: + return b + else: + return a * b + + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) + + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, + b, + # pyrefly: ignore [bad-argument-type] + _keep_float(IntTrueDiv), + ) + + @staticmethod + def truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, + b, + # pyrefly: ignore [bad-argument-type] + _keep_float(FloatTrueDiv), + ) + + @staticmethod + def floordiv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + + # TODO We shall assume division is always valid probably. + if 0 in b: + if b.lower >= 0 and a.lower >= 0: + return ValueRanges(0, int_oo) + if b.upper <= 0 and a.upper <= 0: + return ValueRanges(0, int_oo) + if b.upper <= 0 and a.lower >= 0: + return ValueRanges(-int_oo, 0) + if b.lower >= 0 and a.upper <= 0: + return ValueRanges(-int_oo, 0) + return ValueRanges.unknown_int() + products = [] + for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): + r = FloorDiv(x, y) + if r is sympy.nan: + products.append((sympy.sign(x) * sympy.sign(y)) * int_oo) + else: + products.append(r) + + return ValueRanges(min(products), max(products)) + + @classmethod + def mod(cls, x, y): + x = ValueRanges.wrap(x) + y = ValueRanges.wrap(y) + # nb. We implement C semantics + + def c_mod(a, b): + ret = abs(a) % abs(b) + if a < 0: + ret *= -1 + return ret + + def c_div(a, b): + x = a / b + return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x + + if 0 in y: + return ValueRanges.unknown_int() + elif y.is_singleton(): + y_val = abs(y.lower) + # If it wraps, we need to take the whole interval + + # The function is locally linear if they are in the same class + if c_div(x.lower, y_val) == c_div(x.upper, y_val): + return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val)) + if x.upper < 0: + # Negative case + return ValueRanges(-y_val + 1, 0) + elif x.lower > 0: + # Positive case + return ValueRanges(0, y_val - 1) + else: + # Mixed case + lower = max(-y_val + 1, x.lower) + upper = min(y_val - 1, x.upper) + return ValueRanges(lower, upper) + else: + # Too difficult, we bail out + upper = cls.abs(y).upper - 1 + return ValueRanges(-upper, upper) + + @classmethod + def python_mod(cls, x, y): + """Python-style modulo: result has same sign as divisor. + + Assumes valid input where y is never 0. + - When y > 0: result is in [0, y - 1] + - When y < 0: result is in [y + 1, 0] + """ + + x = ValueRanges.wrap(x) + y = ValueRanges.wrap(y) + if x.lower >= 0 and y.lower >= 0: + return SymPyValueRangeAnalysis.mod(x, y) + lower = y.lower + 1 if y.lower < 0 else 0 + upper = y.upper - 1 if y.upper > 0 else 0 + return ValueRanges(lower, upper) + + @classmethod + def modular_indexing(cls, a, b, c): + return cls.mod(cls.floordiv(a, b), c) + + @classmethod + def is_non_overlapping_and_dense_indicator(cls, *args): + return ValueRanges.unknown_int() + + @classmethod + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + # pyrefly: ignore [no-matching-overload] + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, int_oo), PowByNatural + ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) + ) + + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + + # We could implement all this, but for floating point pow, is there + # really a point? + """ + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + + # Not implemented yet. It's a bit tricky + # If you want to implement it, compute the partial derivatives of a ** b + # and check the ranges where the function is increasing / decreasing + # Another non-tight way of doing this is defaulting to doing noting that for a > 0, a ** b == exp(b * log(a)) + # If this second option is implemented, by carefult about the types and possible infinities here and there. + if not b.is_singleton(): + return ValueRanges.unknown() + + b = b.lower + if a.is_singleton(): + a = a.lower + r = a**b + if not r.is_finite: + return ValueRanges.unknown() + return ValueRanges.wrap(r) + + if b == 0: + if not a.lower.is_finite: + return ValueRanges.unknown() + return ValueRanges.wrap(1.0) + + if b < 0: + a = cls.reciprocal(a) + b = -b + + if a == ValueRanges.unknown(): + return ValueRanges.unknown() + + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) + else: + return ValueRanges.unknown() + """ + + @staticmethod + def reciprocal(x): + """Needed as it's used in pow, but it won't appear on a SymPy expression""" + x = ValueRanges.wrap(x) + if 0 in x: + return ValueRanges.unknown() + else: + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) # type: ignore[operator] + + @staticmethod + def abs(x): + return ValueRanges.convex_min_zero_map(x, abs) + + @staticmethod + def exp(x): + return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp) + + @staticmethod + def log(x): + x = ValueRanges.wrap(x) + if x.lower <= 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_log) + + @staticmethod + def log2(x): + x = ValueRanges.wrap(x) + if x.lower <= 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_log2) + + @classmethod + def minimum(cls, a, b): + return cls.min_or_max(a, b, sympy.Min) + + @classmethod + def maximum(cls, a, b): + return cls.min_or_max(a, b, sympy.Max) + + @staticmethod + def min_or_max(a, b, fn): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) + + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. + + @classmethod + def floor(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) + + @classmethod + def ceil(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.ceiling) + ) + + @classmethod + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + + return ValueRanges.increasing_map(number, fn) + + @classmethod + def round_to_int(cls, number, dtype): + # pyrefly: ignore [bad-argument-type] + return ValueRanges.increasing_map(number, RoundToInt) + + # It's used in some models on symints + @staticmethod + def sqrt(x): + x = ValueRanges.wrap(x) + if x.lower < 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt) + + @staticmethod + def where(a, b, c): + b = ValueRanges.wrap(b) + c = ValueRanges.wrap(c) + a = a.boolify() + # We sometimes write unknown without specifying the type correctly + # In particular, we do that when initialising the bounds for loads in bounds.py + if b.is_bool != c.is_bool and ValueRanges.unknown() not in (b, c): + raise AssertionError( + "where() requires b and c to have the same boolean-ness or allow unknown()" + ) + if b.is_bool: + return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper)) + else: + return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper)) + + # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise. + # We just return the value range of the expression and its corresponding condition as a tuple + # and defer the analysis to piecewise + @staticmethod + def expr_cond_pair(a, b): + b = b.boolify() + return (a, b) + + # piecewise function can be used to convert a SymBool to SymInt: + # int_expr = Piecewise((1, bool_expr), (0, True)), it evaluates to 1 when sym_bool is True and 0 otherwise. + # + # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair. + # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True. + @staticmethod + def piecewise(*ranges): + init_range = None + for expr_range, cond_range in ranges: + if sympy.true in cond_range: + if init_range is None: + init_range = expr_range + else: + init_range = init_range | expr_range + return init_range + + @staticmethod + def cos(x): + # TODO: We should tighten value ranges + # If input range span is pi + 2*pi*k, then output range is (-1, 1) + # otherwise the minimum of the value of the function on the extremes + return ValueRanges(-1.0, 1.0) + + @staticmethod + def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ + x = ValueRanges.wrap(x) + if x.lower > 0: + return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) + elif x.upper < 0: + return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) + return ValueRanges(0.0, sympy.oo) + """ + + @staticmethod + def sin(x): + # TODO: We should tighten value ranges + # See details on cos + return ValueRanges(-1.0, 1.0) + + @staticmethod + def sinh(x): + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def tan(x): + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def tanh(x): + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ + x = ValueRanges.wrap(x) + if -1 <= x.lower and x.upper <= 1: + return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) + return ValueRanges.unknown() + """ + + @staticmethod + def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ + x = ValueRanges.wrap(x) + if -1 <= x.lower and x.upper <= 1: + return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) + return ValueRanges.unknown() + """ + + @staticmethod + def atan(x): + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + + @staticmethod + def trunc(x): + # pyrefly: ignore [bad-argument-type] + return ValueRanges.increasing_map(x, TruncToFloat) + + +def bound_sympy( + expr: sympy.Expr, ranges: dict[sympy.Symbol, ValueRanges] | None = None +) -> ValueRanges: + log.debug( + "bound_sympy(%s)%s", + expr, + LazyString( + lambda: ( + "\n" + + "\n".join( + f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols + ) + if ranges + else "" + ) + ), + ) + if isinstance(expr, sympy.Number): + return ValueRanges.wrap(expr) + + ranges = ranges or {} + + # If there's a tracing context, augment available constrained ranges. + context = torch._guards.TracingContext.try_get() + if context and context.fake_mode and context.fake_mode.shape_env: + if ranges: + ranges = {**context.fake_mode.shape_env.var_to_range, **ranges} + else: + ranges = context.fake_mode.shape_env.var_to_range + + def missing_handler(s): + if s.is_integer: # type: ignore[attr-defined] + if s.is_positive: # type: ignore[attr-defined] + vr = ValueRanges(1, int_oo) + elif s.is_nonnegative: # type: ignore[attr-defined] + vr = ValueRanges(0, int_oo) + else: + vr = ValueRanges.unknown_int() + else: + # Don't bother trying very hard here + vr = ValueRanges.unknown() + return vr + + return sympy_interp( + SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ec989be1e078ba857d30b06a91e1dc54131e4b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +from torch._C import ( + _get_backcompat_broadcast_warn, + _get_backcompat_keepdim_warn, + _set_backcompat_broadcast_warn, + _set_backcompat_keepdim_warn, +) + + +class Warning: + def __init__(self, setter, getter) -> None: + self.setter = setter + self.getter = getter + + def set_enabled(self, value) -> None: + self.setter(value) + + def get_enabled(self): + return self.getter() + + enabled = property(get_enabled, set_enabled) + + +broadcast_warning = Warning( + _set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn +) +keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7cd55bd93a1f56e8cd45f623cfbc29082a7e943 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e814aaf4671ca35484c43bc38677849d02a81ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py @@ -0,0 +1,6 @@ +from torch.utils.benchmark.utils.common import * # noqa: F403 +from torch.utils.benchmark.utils.timer import * # noqa: F403 +from torch.utils.benchmark.utils.compare import * # noqa: F403 +from torch.utils.benchmark.utils.fuzzer import * # noqa: F403 +from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import * # noqa: F403 +from torch.utils.benchmark.utils.sparse_fuzzer import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03297c6a2722bd826db1cb249636c6e73a04b9f4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76e97e2872c816ba969d5516f89e4c8ddf7f94fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3cfe61957f27c7e14de2a4439e2fb403435fda Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25a3c30b7deed5e7dcc6343b6ec13399ae55a6b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30e6b7edd0ad4788cdd0d3be5f4e20addfd420bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..999264ace349610fc4f6c6497898e1d22e6f66f6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f74bba0107001f1ee27f94a31341ea01e500028d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py new file mode 100644 index 0000000000000000000000000000000000000000..1c266e7cf9a6e604c94dfb28f19f31f1649220f4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +"""Example of Timer and Compare APIs: + +$ python -m examples.compare +""" + +import pickle +import sys +import time + +import torch + +import torch.utils.benchmark as benchmark_utils + + +class FauxTorch: + """Emulate different versions of pytorch. + + In normal circumstances this would be done with multiple processes + writing serialized measurements, but this simplifies that model to + make the example clearer. + """ + def __init__(self, real_torch, extra_ns_per_element) -> None: + self._real_torch = real_torch + self._extra_ns_per_element = extra_ns_per_element + + def extra_overhead(self, result): + # time.sleep has a ~65 us overhead, so only fake a + # per-element overhead if numel is large enough. + numel = int(result.numel()) + if numel > 5000: + time.sleep(numel * self._extra_ns_per_element * 1e-9) + return result + + def add(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.add(*args, **kwargs)) + + def mul(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.mul(*args, **kwargs)) + + def cat(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.cat(*args, **kwargs)) + + def matmul(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.matmul(*args, **kwargs)) + + +def main() -> None: + tasks = [ + ("add", "add", "torch.add(x, y)"), + ("add", "add (extra +0)", "torch.add(x, y + zero)"), + ] + + serialized_results = [] + repeats = 2 + timers = [ + benchmark_utils.Timer( + stmt=stmt, + globals={ + "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns), + "x": torch.ones((size, 4)), + "y": torch.ones((1, 4)), + "zero": torch.zeros(()), + }, + label=label, + sub_label=sub_label, + description=f"size: {size}", + env=branch, + num_threads=num_threads, + ) + for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)] + for label, sub_label, stmt in tasks + for size in [1, 10, 100, 1000, 10000, 50000] + for num_threads in [1, 4] + ] + + for i, timer in enumerate(timers * repeats): + serialized_results.append(pickle.dumps( + timer.blocked_autorange(min_run_time=0.05) + )) + print(f"\r{i + 1} / {len(timers) * repeats}", end="") + sys.stdout.flush() + print() + + comparison = benchmark_utils.Compare([ + pickle.loads(i) for i in serialized_results + ]) + + print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n") + comparison.print() + + print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n") + comparison.trim_significant_figures() + comparison.colorize() + comparison.print() + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..80a4e733928d8b059919d847da1b461d55dd7402 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +"""Example of the Timer and Fuzzer APIs: + +$ python -m examples.fuzzer +""" + +import sys + +import torch.utils.benchmark as benchmark_utils + + +def main() -> None: + add_fuzzer = benchmark_utils.Fuzzer( + parameters=[ + [ + benchmark_utils.FuzzedParameter( + name=f"k{i}", + minval=16, + maxval=16 * 1024, + distribution="loguniform", + ) for i in range(3) + ], + benchmark_utils.FuzzedParameter( + name="d", + distribution={2: 0.6, 3: 0.4}, + ), + ], + tensors=[ + [ + benchmark_utils.FuzzedTensor( + name=name, + size=("k0", "k1", "k2"), + dim_parameter="d", + probability_contiguous=0.75, + min_elements=64 * 1024, + max_elements=128 * 1024, + ) for name in ("x", "y") + ], + ], + seed=0, + ) + + n = 250 + measurements = [] + for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)): + x, x_order = tensors["x"], str(tensor_properties["x"]["order"]) + y, y_order = tensors["y"], str(tensor_properties["y"]["order"]) + shape = ", ".join(tuple(f'{i:>4}' for i in x.shape)) + + description = "".join([ + f"{x.numel():>7} | {shape:<16} | ", + f"{'contiguous' if x.is_contiguous() else x_order:<12} | ", + f"{'contiguous' if y.is_contiguous() else y_order:<12} | ", + ]) + + timer = benchmark_utils.Timer( + stmt="x + y", + globals=tensors, + description=description, + ) + + measurements.append(timer.blocked_autorange(min_run_time=0.1)) + measurements[-1].metadata = {"numel": x.numel()} + print(f"\r{i + 1} / {n}", end="") + sys.stdout.flush() + print() + + # More string munging to make pretty output. + print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") + + def time_fn(m): + return m.median / m.metadata["numel"] + measurements.sort(key=time_fn) + + template = f"{{:>6}}{' ' * 19}Size Shape{' ' * 13}X order Y order\n{'-' * 80}" + print(template.format("Best:")) + for m in measurements[:15]: + print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}") + + print("\n" + template.format("Worst:")) + for m in measurements[-15:]: + print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}") + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f65599ee18a4f2c4a0d35b514c8f87725affae01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +"""Example use of Timer and op fuzzers to measure kernel performance. + +$ python -m examples.op_benchmark +""" + +import numpy as np +import torch + +from torch.utils.benchmark import Timer +from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer +from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer +import operator + + +_MEASURE_TIME = 1.0 + + +def assert_dicts_equal(dict_0, dict_1) -> None: + """Builtin dict comparison will not compare numpy arrays. + e.g. + x = {"a": np.ones((2, 1))} + x == x # Raises ValueError + """ + if set(dict_0.keys()) != set(dict_0.keys()): + raise AssertionError("dicts must have the same keys") + if all(np.all(v != dict_1[k]) for k, v in dict_0.items() if k != "dtype"): + raise AssertionError("dict values differ for keys other than 'dtype'") + + +def run(n, stmt, fuzzer_cls) -> None: + float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) + int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n) + raw_results = [] + for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter, strict=True)): + float_tensors, float_tensor_params, float_params = float_values + int_tensors, int_tensor_params, int_params = int_values + + # This benchmark assumes that the two fuzzers generate identically + # sized and strided Tensors, since the same seed is used. + assert_dicts_equal(float_params, int_params) + assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"]) + + float_measurement, int_measurement = ( + Timer( + stmt, + globals=tensors, + ).blocked_autorange(min_run_time=_MEASURE_TIME) + for tensors in (float_tensors, int_tensors) + ) + + descriptions = [] + for name in float_tensors: + shape_str = "(" + ", ".join([ + f"2 ** {int(np.log2(i))}" + if 2 ** int(np.log2(i)) == i and i > 1 + else str(i) + for i in float_tensors[name].shape + ]) + ")" + order = float_tensor_params[name]["order"] + order_str = ("" if all(order == np.arange(len(order))) else str(tuple(order))) + steps = float_tensor_params[name]["steps"] + steps_str = str(steps) if sum(steps) > len(steps) else "" + descriptions.append((name, shape_str, order_str, steps_str)) + raw_results.append((float_measurement, int_measurement, descriptions)) + + print(f"\r{i + 1} / {n}", end="") + print() + + parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0 + for float_measurement, int_measurement, descriptions in raw_results: + t_float = float_measurement.median * 1e6 + t_int = int_measurement.median * 1e6 + rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2 + parsed_results.append((t_float, t_int, rel_diff, descriptions)) + for name, shape, order, steps in descriptions: + name_len = max(name_len, len(name)) + shape_len = max(shape_len, len(shape)) + order_len = max(order_len, len(order)) + steps_len = max(steps_len, len(steps)) + + parsed_results.sort(key=operator.itemgetter(2)) + + print(f"stmt: {stmt}") + print(f" diff faster{'':>17}{' ' * name_len} ", end="") + print(f"{'shape'.ljust(shape_len)}{'':>16}{'order'.ljust(order_len)}", end="") + print(f" steps\n{'-' * 100}") + for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]: + for t_float, t_int, rel_diff, descriptions in results: + time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"] + time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]]) + for t_str, (name, shape, order, steps) in zip(time_str, descriptions, strict=True): + name = f"{name}:".ljust(name_len + 1) + shape = shape.ljust(shape_len + 10) + order = order.ljust(order_len) + print(f"{t_str} {name} {shape}| {order} | {steps}") + print(spacer) + + +def main() -> None: + run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer) + run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer) + run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer) + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py new file mode 100644 index 0000000000000000000000000000000000000000..8137d4d8791975b46b1314c2f3a05ed048dbdcd3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py @@ -0,0 +1,25 @@ +"""Trivial use of Timer API: + +$ python -m examples.simple_timeit +""" + +import torch + +import torch.utils.benchmark as benchmark_utils + + +def main() -> None: + timer = benchmark_utils.Timer( + stmt="x + y", + globals={"x": torch.ones((4, 8)), "y": torch.ones((1, 8))}, + label="Broadcasting add (4x8)", + ) + + for i in range(3): + print(f"Run: {i}\n{'-' * 40}") + print(f"timeit:\n{timer.timeit(10000)}\n") + print(f"autorange:\n{timer.blocked_autorange()}\n\n") + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81a33c34bc8229a44838ea93c29af34895061c53 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +"""Microbenchmarks for the torch.fft module""" +from argparse import ArgumentParser +from collections import namedtuple +from collections.abc import Iterable + +import torch +import torch.fft +from torch.utils import benchmark +from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer + + +def _dim_options(ndim): + if ndim == 1: + return [None] + elif ndim == 2: + return [0, 1, None] + elif ndim == 3: + return [0, 1, 2, (0, 1), (0, 2), None] + raise ValueError(f"Expected ndim in range 1-3, got {ndim}") + + +def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int, + probability_regular: float): + cuda = device == 'cuda' + spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda, + probability_regular=probability_regular) + results = [] + for tensors, tensor_params, params in spectral_fuzzer.take(samples): + shape = [params['k0'], params['k1'], params['k2']][:params['ndim']] + str_shape = ' x '.join([f"{s:<4}" for s in shape]) + sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}" + for dim in _dim_options(params['ndim']): + for nthreads in (1, 4, 16) if not cuda else (1,): + measurement = benchmark.Timer( + stmt='func(x, dim=dim)', + globals={'func': function, 'x': tensors['x'], 'dim': dim}, + label=f"{name}_{device}", + sub_label=sub_label, + description=f"dim={dim}", + num_threads=nthreads, + ).blocked_autorange(min_run_time=1) + measurement.metadata = { + 'name': name, + 'device': device, + 'dim': dim, + 'shape': shape, + } + measurement.metadata.update(tensor_params['x']) + results.append(measurement) + return results + + +Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype']) +BENCHMARKS = [ + Benchmark('fft_real', torch.fft.fftn, torch.float32), + Benchmark('fft_complex', torch.fft.fftn, torch.complex64), + Benchmark('ifft', torch.fft.ifftn, torch.complex64), + Benchmark('rfft', torch.fft.rfftn, torch.float32), + Benchmark('irfft', torch.fft.irfftn, torch.complex64), +] +BENCHMARK_MAP = {b.name: b for b in BENCHMARKS} +BENCHMARK_NAMES = [b.name for b in BENCHMARKS] +DEVICE_NAMES = ['cpu', 'cuda'] + +def _output_csv(file, results) -> None: + file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n') + for measurement in results: + metadata = measurement.metadata + device, dim, shape, name, numel, contiguous = ( + metadata['device'], metadata['dim'], metadata['shape'], + metadata['name'], metadata['numel'], metadata['is_contiguous']) + + if isinstance(dim, Iterable): + dim_str = '-'.join(str(d) for d in dim) + else: + dim_str = str(dim) + shape_str = 'x'.join(str(s) for s in shape) + + print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str, # type: ignore[possibly-undefined] + measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6, + sep=',', file=file) + + +if __name__ == '__main__': + parser = ArgumentParser(description=__doc__) + parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES) + parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--samples', type=int, default=10) + parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0) + parser.add_argument('-o', '--output', type=str) + args = parser.parse_args() + + num_benchmarks = len(args.device) * len(args.bench) + i = 0 + results = [] + for device in args.device: + for bench in (BENCHMARK_MAP[b] for b in args.bench): + results += run_benchmark( + name=bench.name, function=bench.function, dtype=bench.dtype, + seed=args.seed, device=device, samples=args.samples, + probability_regular=args.probability_regular) + i += 1 + print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})') + + if args.output is not None: + with open(args.output, 'w') as f: + _output_csv(f, results) + + compare = benchmark.Compare(results) + compare.trim_significant_figures() + compare.colorize() + compare.print() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a81ddc51248fe764e29c3d35f3d7ca21fcd383e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64752775e78f9fb1b4312e37f952d01d5f77e217 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74c8fe31f1e902211ded0e18811694e300fe5a68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad70a8bec47505447e063b4ef6bc01215411a9e3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e113e2f2a2a57cf52128cbba221ed61e77a94d36 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48f08043907c342fc162f4e95e06330155dd4a58 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..e53c310111bec8166e6090f351e39153dbe400aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class BinaryOpFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: + super().__init__( + parameters=[ + # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x` and `y`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + # Moreover, `y` will occasionally have singleton + # dimensions in order to test broadcasting. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + + [ + FuzzedParameter( + name=f"y_k{i}", + distribution={ + ParameterAlias(f"k{i}"): 0.8, + 1: 0.2, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x` and `y`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"{name}_step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) + for i in range(3) + for name in ("x", "y") + ], + + # Repeatable entropy for downstream applications. + FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + FuzzedTensor( + name="y", + size=("y_k0", "y_k1", "y_k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6269464e0d53d2c3c51ed5406d7c88598fec79 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class BinaryOpSparseFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: + super().__init__( + parameters=[ + # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + FuzzedParameter( + name="sparse_dim", + distribution={1: 0.4, 2: 0.4, 3: 0.2}, + strict=True + ), + # Shapes for `x` and `y`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + # Moreover, `y` will occasionally have singleton + # dimensions in order to test broadcasting. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"y_k{i}", + distribution={ + ParameterAlias(f"k{i}"): 1.0}, + strict=True, + ) for i in range(3) + ], + FuzzedParameter( + name="density", + distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3}, + ), + FuzzedParameter( + name="coalesced", + distribution={True: 0.5, False: 0.5}, + ), + # Repeatable entropy for downstream applications. + FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedSparseTensor( + name="x", + size=("k0", "k1", "k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + density="density", + coalesced="coalesced", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + dtype=dtype, + cuda=cuda, + ), + FuzzedSparseTensor( + name="y", + size=("y_k0", "y_k1", "y_k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + density="density", + coalesced="coalesced", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..18921becd078cb3140a1705078dd57f4a597a2ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from torch.types import _dtype + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor + +__all__ = ["UnaryOpSparseFuzzer"] + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + +class UnaryOpSparseFuzzer(Fuzzer): + def __init__(self, seed: int | None, dtype: _dtype | None = None, cuda: bool = False) -> None: + if dtype is None: + dtype = getattr(torch, 'float32', None) + super().__init__( + parameters=[ + # Sparse dim parameter of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + FuzzedParameter( + name="sparse_dim", + distribution={1: 0.4, 2: 0.4, 3: 0.2}, + strict=True + ), + # Shapes for `x`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + FuzzedParameter( + name="density", + distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3}, + ), + FuzzedParameter( + name="coalesced", + distribution={True: 0.5, False: 0.5}, + ), + FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedSparseTensor( + name="x", + size=("k0", "k1", "k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + density="density", + coalesced="coalesced", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..c324e338dca5da3d2b8b9a55e7d89f108d6783dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch.utils import benchmark +from torch.utils.benchmark import FuzzedParameter, FuzzedTensor, ParameterAlias + + +__all__ = ['SpectralOpFuzzer'] + +MIN_DIM_SIZE = 16 +MAX_DIM_SIZE = 16 * 1024 + +def power_range(upper_bound, base): + return (base ** i for i in range(int(math.log(upper_bound, base)) + 1)) + +# List of regular numbers from MIN_DIM_SIZE to MAX_DIM_SIZE +# These numbers factorize into multiples of prime factors 2, 3, and 5 only +# and are usually the fastest in FFT implementations. +REGULAR_SIZES = [] +for i in power_range(MAX_DIM_SIZE, 2): + for j in power_range(MAX_DIM_SIZE // i, 3): + ij = i * j + for k in power_range(MAX_DIM_SIZE // ij, 5): + ijk = ij * k + if ijk > MIN_DIM_SIZE: + REGULAR_SIZES.append(ijk) +REGULAR_SIZES.sort() + +class SpectralOpFuzzer(benchmark.Fuzzer): + def __init__(self, *, seed: int, dtype=torch.float64, + cuda: bool = False, probability_regular: float = 1.0) -> None: + super().__init__( + parameters=[ + # Dimensionality of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("ndim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x`. + # It is important to test all shapes, however + # regular sizes are especially important to the FFT and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the regular numbers + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=MIN_DIM_SIZE, + maxval=MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_regular_{i}", + distribution={size: 1. / len(REGULAR_SIZES) for size in REGULAR_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_regular_{i}"): probability_regular, + ParameterAlias(f"k_any_{i}"): 1 - probability_regular, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) for i in range(3) + ], + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("step_0", "step_1", "step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="ndim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py new file mode 100644 index 0000000000000000000000000000000000000000..6008adfe459218cd0e239efede5a3f1cd35ee61b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class UnaryOpFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: + super().__init__( + parameters=[ + # Dimensionality of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"x_step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) for i in range(3) + ], + + # Repeatable entropy for downstream applications. + FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8bedca989e9e299b7aa761baba38cd1fe2bdba1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bd70c32ebbd56ea03095a398abb75841cb52b3f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2accfb6f763143e11aae9d8b235cf54ce43ec1d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2288c58dc3a415201c9149f835f2aca1e05dd1bf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d0e26b63fbd0cb18b58b0300dfb28c302f53abe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f59b1fd7d4a4e87adbef78ca6d3fb9b7d3dd52 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e2ab5c24607926e6568a4cb7fa29b61ac9b0735 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e7302b04b6d74b3d6bc0962030ece953771c518 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf100e77f5451f2f8fc8f2329dba672c4ef770a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..c91e3d12b29e1c050edbadaebb877d7fc0761e57 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py @@ -0,0 +1,42 @@ +from typing import Any +from collections.abc import Callable +from typing_extensions import Protocol, runtime_checkable + + +class TimerClass(Protocol): + """This is the portion of the `timeit.Timer` API used by benchmark utils.""" + def __init__( + self, + stmt: str, + setup: str, + timer: Callable[[], float], + globals: dict[str, Any], + **kwargs: Any, + ) -> None: + ... + + def timeit(self, number: int) -> float: + ... + + +@runtime_checkable +class TimeitModuleType(Protocol): + """Modules generated from `timeit_template.cpp`.""" + def timeit(self, number: int) -> float: + ... + + +class CallgrindModuleType(Protocol): + """Replicates the valgrind endpoints in `torch._C`. + + These bindings are used to collect Callgrind profiles on earlier versions + of PyTorch and will eventually be removed. + """ + __file__: str + __name__: str + + def _valgrind_supported_platform(self) -> bool: + ... + + def _valgrind_toggle(self) -> None: + ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f328d19083f0fc92da79e34d70a68b8ef891ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py @@ -0,0 +1,359 @@ +"""Base shared classes and utilities.""" + +import collections +import contextlib +import dataclasses +import os +import shutil +import tempfile +import textwrap +import time +from typing import cast, Any +from collections.abc import Iterable, Iterator +import uuid + +import torch + + +__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"] + + +_MAX_SIGNIFICANT_FIGURES = 4 +_MIN_CONFIDENCE_INTERVAL = 25e-9 # 25 ns + +# Measurement will include a warning if the distribution is suspect. All +# runs are expected to have some variation; these parameters set the +# thresholds. +_IQR_WARN_THRESHOLD = 0.1 +_IQR_GROSS_WARN_THRESHOLD = 0.25 + + +@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True) +class TaskSpec: + """Container for information used to define a Timer. (except globals)""" + stmt: str + setup: str + global_setup: str = "" + label: str | None = None + sub_label: str | None = None + description: str | None = None + env: str | None = None + num_threads: int = 1 + + @property + def title(self) -> str: + """Best effort attempt at a string label for the measurement.""" + if self.label is not None: + return self.label + (f": {self.sub_label}" if self.sub_label else "") + elif "\n" not in self.stmt: + return self.stmt + (f": {self.sub_label}" if self.sub_label else "") + return ( + f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n" + f"{textwrap.indent(self.stmt, ' ')}" + ) + + def setup_str(self) -> str: + return ( + "" if (self.setup == "pass" or not self.setup) + else f"setup:\n{textwrap.indent(self.setup, ' ')}" if "\n" in self.setup + else f"setup: {self.setup}" + ) + + def summarize(self) -> str: + """Build TaskSpec portion of repr string for other containers.""" + sections = [ + self.title, + self.description or "", + self.setup_str(), + ] + return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i]) + +_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec)) + + +@dataclasses.dataclass(init=True, repr=False) +class Measurement: + """The result of a Timer measurement. + + This class stores one or more measurements of a given statement. It is + serializable and provides several convenience methods + (including a detailed __repr__) for downstream consumers. + """ + number_per_run: int + raw_times: list[float] + task_spec: TaskSpec + metadata: dict[Any, Any] | None = None # Reserved for user payloads. + + def __post_init__(self) -> None: + self._sorted_times: tuple[float, ...] = () + self._warnings: tuple[str, ...] = () + self._median: float = -1.0 + self._mean: float = -1.0 + self._p25: float = -1.0 + self._p75: float = -1.0 + + def __getattr__(self, name: str) -> Any: + # Forward TaskSpec fields for convenience. + if name in _TASKSPEC_FIELDS: + return getattr(self.task_spec, name) + return super().__getattribute__(name) + + # ========================================================================= + # == Convenience methods for statistics =================================== + # ========================================================================= + # + # These methods use raw time divided by number_per_run; this is an + # extrapolation and hides the fact that different number_per_run will + # result in different amortization of overheads, however if Timer has + # selected an appropriate number_per_run then this is a non-issue, and + # forcing users to handle that division would result in a poor experience. + @property + def times(self) -> list[float]: + return [t / self.number_per_run for t in self.raw_times] + + @property + def median(self) -> float: + self._lazy_init() + return self._median + + @property + def mean(self) -> float: + self._lazy_init() + return self._mean + + @property + def iqr(self) -> float: + self._lazy_init() + return self._p75 - self._p25 + + @property + def significant_figures(self) -> int: + """Approximate significant figure estimate. + + This property is intended to give a convenient way to estimate the + precision of a measurement. It only uses the interquartile region to + estimate statistics to try to mitigate skew from the tails, and + uses a static z value of 1.645 since it is not expected to be used + for small values of `n`, so z can approximate `t`. + + The significant figure estimation used in conjunction with the + `trim_sigfig` method to provide a more human interpretable data + summary. __repr__ does not use this method; it simply displays raw + values. Significant figure estimation is intended for `Compare`. + """ + self._lazy_init() + n_total = len(self._sorted_times) + lower_bound = int(n_total // 4) + upper_bound = int(torch.tensor(3 * n_total / 4).ceil()) + interquartile_points: tuple[float, ...] = self._sorted_times[lower_bound:upper_bound] + std = torch.tensor(interquartile_points).std(unbiased=False).item() + sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item() + + # Rough estimates. These are by no means statistically rigorous. + confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL) + relative_ci = torch.tensor(self._median / confidence_interval).log10().item() + num_significant_figures = int(torch.tensor(relative_ci).floor()) + return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES) + + @property + def has_warnings(self) -> bool: + self._lazy_init() + return bool(self._warnings) + + def _lazy_init(self) -> None: + if self.raw_times and not self._sorted_times: + self._sorted_times = tuple(sorted(self.times)) + _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64) + self._median = _sorted_times.quantile(.5).item() + self._mean = _sorted_times.mean().item() + self._p25 = _sorted_times.quantile(.25).item() + self._p75 = _sorted_times.quantile(.75).item() + + def add_warning(msg: str) -> None: + rel_iqr = self.iqr / self.median * 100 + self._warnings += ( + f" WARNING: Interquartile range is {rel_iqr:.1f}% " + f"of the median measurement.\n {msg}", + ) + + if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD): + add_warning("This suggests significant environmental influence.") + elif not self.meets_confidence(_IQR_WARN_THRESHOLD): + add_warning("This could indicate system fluctuation.") + + + def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool: + return self.iqr / self.median < threshold + + @property + def title(self) -> str: + return self.task_spec.title + + @property + def env(self) -> str: + return ( + "Unspecified env" if self.taskspec.env is None + else cast(str, self.taskspec.env) + ) + + @property + def as_row_name(self) -> str: + return self.sub_label or self.stmt or "[Unknown]" + + def __repr__(self) -> str: + """ + Example repr: + + Broadcasting add (4x8) + Median: 5.73 us + IQR: 2.25 us (4.01 to 6.26) + 372 measurements, 100 runs per measurement, 1 thread + WARNING: Interquartile range is 39.4% of the median measurement. + This suggests significant environmental influence. + """ + self._lazy_init() + skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n" + n = len(self._sorted_times) + time_unit, time_scale = select_unit(self._median) + iqr_filter = '' if n >= 4 else skip_line + + repr_str = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit} + {iqr_filter}IQR: {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f}) + {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''} +{newline.join(self._warnings)}""".strip() # noqa: B950 + + return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l) + + @staticmethod + def merge(measurements: Iterable["Measurement"]) -> list["Measurement"]: + """Convenience method for merging replicates. + + Merge will extrapolate times to `number_per_run=1` and will not + transfer any metadata. (Since it might differ between replicates) + """ + grouped_measurements: collections.defaultdict[TaskSpec, list[Measurement]] = collections.defaultdict(list) + for m in measurements: + grouped_measurements[m.task_spec].append(m) + + def merge_group(task_spec: TaskSpec, group: list["Measurement"]) -> "Measurement": + times: list[float] = [] + for m in group: + # Different measurements could have different `number_per_run`, + # so we call `.times` which normalizes the results. + times.extend(m.times) + + return Measurement( + number_per_run=1, + raw_times=times, + task_spec=task_spec, + metadata=None, + ) + + return [merge_group(t, g) for t, g in grouped_measurements.items()] + + +def select_unit(t: float) -> tuple[str, float]: + """Determine how to scale times for O(1) magnitude. + + This utility is used to format numbers for human consumption. + """ + time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(torch.tensor(t).log10().item() // 3), "s") + time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit] + return time_unit, time_scale + + +def unit_to_english(u: str) -> str: + return { + "ns": "nanosecond", + "us": "microsecond", + "ms": "millisecond", + "s": "second", + }[u] + + +def trim_sigfig(x: float, n: int) -> float: + """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)""" + if n != int(n): + raise AssertionError("Number of significant figures must be an integer") + magnitude = int(torch.tensor(x).abs().log10().ceil().item()) + scale = 10 ** (magnitude - n) + return float(torch.tensor(x / scale).round() * scale) + + +def ordered_unique(elements: Iterable[Any]) -> list[Any]: + return list(collections.OrderedDict(dict.fromkeys(elements)).keys()) + + +@contextlib.contextmanager +def set_torch_threads(n: int) -> Iterator[None]: + prior_num_threads = torch.get_num_threads() + try: + torch.set_num_threads(n) + yield + finally: + torch.set_num_threads(prior_num_threads) + + +def _make_temp_dir(prefix: str | None = None, gc_dev_shm: bool = False) -> str: + """Create a temporary directory. The caller is responsible for cleanup. + + This function is conceptually similar to `tempfile.mkdtemp`, but with + the key additional feature that it will use shared memory if the + `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an + implementation detail, but an important one for cases where many Callgrind + measurements are collected at once. (Such as when collecting + microbenchmarks.) + + This is an internal utility, and is exported solely so that microbenchmarks + can reuse the util. + """ + use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true") + if use_dev_shm: + root = "/dev/shm/pytorch_benchmark_utils" + if os.name != "posix": + raise AssertionError(f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}") + if not os.path.exists("/dev/shm"): + raise AssertionError("This system does not appear to support tmpfs (/dev/shm).") + os.makedirs(root, exist_ok=True) + + # Because we're working in shared memory, it is more important than + # usual to clean up ALL intermediate files. However we don't want every + # worker to walk over all outstanding directories, so instead we only + # check when we are sure that it won't lead to contention. + if gc_dev_shm: + for i in os.listdir(root): + owner_file = os.path.join(root, i, "owner.pid") + if not os.path.exists(owner_file): + continue + + with open(owner_file) as f: + owner_pid = int(f.read()) + + if owner_pid == os.getpid(): + continue + + try: + # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python + os.kill(owner_pid, 0) + + except OSError: + print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.") + shutil.rmtree(os.path.join(root, i)) + + else: + root = tempfile.gettempdir() + + # We include the time so names sort by creation time, and add a UUID + # to ensure we don't collide. + name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}" + path = os.path.join(root, name) + os.makedirs(path, exist_ok=False) + + if use_dev_shm: + with open(os.path.join(path, "owner.pid"), "w") as f: + f.write(str(os.getpid())) + + return path diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py new file mode 100644 index 0000000000000000000000000000000000000000..c1e232e6e04260f277254c9b181c63dfeaadee62 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py @@ -0,0 +1,345 @@ +# mypy: allow-untyped-defs +"""Display class to aggregate and print the results of many measurements.""" +import collections +import enum +import itertools as it + +from torch.utils.benchmark.utils import common +from torch import tensor as _tensor +import operator + +__all__ = ["Colorize", "Compare"] + +BEST = "\033[92m" +GOOD = "\033[34m" +BAD = "\033[2m\033[91m" +VERY_BAD = "\033[31m" +BOLD = "\033[1m" +TERMINATE = "\033[0m" + + +class Colorize(enum.Enum): + NONE = "none" + COLUMNWISE = "columnwise" + ROWWISE = "rowwise" + + +# Classes to separate internal bookkeeping from what is rendered. +class _Column: + def __init__( + self, + grouped_results: list[tuple[common.Measurement | None, ...]], + time_scale: float, + time_unit: str, + trim_significant_figures: bool, + highlight_warnings: bool, + ) -> None: + self._grouped_results = grouped_results + self._flat_results = [*it.chain.from_iterable(grouped_results)] + self._time_scale = time_scale + self._time_unit = time_unit + self._trim_significant_figures = trim_significant_figures + self._highlight_warnings = ( + highlight_warnings + and any(r.has_warnings for r in self._flat_results if r) + ) + leading_digits = [ + int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None + for r in self._flat_results + ] + unit_digits = max(d for d in leading_digits if d is not None) + decimal_digits = min( + max(m.significant_figures - digits, 0) + for digits, m in zip(leading_digits, self._flat_results, strict=True) + if (m is not None) and (digits is not None) + ) if self._trim_significant_figures else 1 + length = unit_digits + decimal_digits + (1 if decimal_digits else 0) + self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}" + + def get_results_for(self, group): + return self._grouped_results[group] + + def num_to_str(self, value: float | None, estimated_sigfigs: int, spread: float | None): + if value is None: + return " " * len(self.num_to_str(1, estimated_sigfigs, None)) + + if self._trim_significant_figures: + value = common.trim_sigfig(value, estimated_sigfigs) + + return self._template.format( + value, + f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "") + + +def optional_min(seq): + l = list(seq) + return None if len(l) == 0 else min(l) + + +class _Row: + def __init__(self, results, row_group, render_env, env_str_len, + row_name_str_len, time_scale, colorize, num_threads=None) -> None: + super().__init__() + self._results = results + self._row_group = row_group + self._render_env = render_env + self._env_str_len = env_str_len + self._row_name_str_len = row_name_str_len + self._time_scale = time_scale + self._colorize = colorize + self._columns: tuple[_Column, ...] = () + self._num_threads = num_threads + + def register_columns(self, columns: tuple[_Column, ...]) -> None: + self._columns = columns + + def as_column_strings(self): + concrete_results = [r for r in self._results if r is not None] + env = f"({concrete_results[0].env})" if self._render_env else "" + env = env.ljust(self._env_str_len + 4) + output = [" " + env + concrete_results[0].as_row_name] + for m, col in zip(self._results, self._columns or (), strict=False): + if m is None: + output.append(col.num_to_str(None, 1, None)) + else: + output.append(col.num_to_str( + m.median / self._time_scale, + m.significant_figures, + m.iqr / m.median if m.has_warnings else None + )) + return output + + @staticmethod + def color_segment(segment, value, best_value): + if value <= best_value * 1.01 or value <= best_value + 100e-9: + return BEST + BOLD + segment + TERMINATE * 2 + if value <= best_value * 1.1: + return GOOD + BOLD + segment + TERMINATE * 2 + if value >= best_value * 5: + return VERY_BAD + BOLD + segment + TERMINATE * 2 + if value >= best_value * 2: + return BAD + segment + TERMINATE * 2 + + return segment + + def row_separator(self, overall_width): + return ( + [f"{self._num_threads} threads: ".ljust(overall_width, "-")] + if self._num_threads is not None else [] + ) + + def finalize_column_strings(self, column_strings, col_widths): + best_values = [-1 for _ in column_strings] + if self._colorize == Colorize.ROWWISE: + row_min = min(r.median for r in self._results if r is not None) + best_values = [row_min for _ in column_strings] + elif self._colorize == Colorize.COLUMNWISE: + best_values = [ + optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None) + for column in (self._columns or ()) + ] + + row_contents = [column_strings[0].ljust(col_widths[0])] + for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values, strict=False): + col_str = col_str.center(width) + if self._colorize != Colorize.NONE and result is not None and best_value is not None: + col_str = self.color_segment(col_str, result.median, best_value) + row_contents.append(col_str) + return row_contents + + +class Table: + def __init__( + self, + results: list[common.Measurement], + colorize: Colorize, + trim_significant_figures: bool, + highlight_warnings: bool + ) -> None: + if len({r.label for r in results}) != 1: + raise AssertionError("All results must share the same label") + + self.results = results + self._colorize = colorize + self._trim_significant_figures = trim_significant_figures + self._highlight_warnings = highlight_warnings + self.label = results[0].label + self.time_unit, self.time_scale = common.select_unit( + min(r.median for r in results) + ) + + self.row_keys = common.ordered_unique([self.row_fn(i) for i in results]) + self.row_keys.sort(key=operator.itemgetter(slice(2))) # preserve stmt order + self.column_keys = common.ordered_unique([self.col_fn(i) for i in results]) + self.rows, self.columns = self.populate_rows_and_columns() + + @staticmethod + def row_fn(m: common.Measurement) -> tuple[int, str | None, str]: + return m.num_threads, m.env, m.as_row_name + + @staticmethod + def col_fn(m: common.Measurement) -> str | None: + return m.description + + def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, ...]]: + rows: list[_Row] = [] + columns: list[_Column] = [] + ordered_results: list[list[common.Measurement | None]] = [ + [None for _ in self.column_keys] + for _ in self.row_keys + ] + row_position = {key: i for i, key in enumerate(self.row_keys)} + col_position = {key: i for i, key in enumerate(self.column_keys)} + for r in self.results: + i = row_position[self.row_fn(r)] + j = col_position[self.col_fn(r)] + ordered_results[i][j] = r + + unique_envs = {r.env for r in self.results} + render_env = len(unique_envs) > 1 + env_str_len = max(len(i) for i in unique_envs) if render_env else 0 + + row_name_str_len = max(len(r.as_row_name) for r in self.results) + + prior_num_threads = -1 + prior_env = "" + row_group = -1 + rows_by_group: list[list[list[common.Measurement | None]]] = [] + for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True): + thread_transition = (num_threads != prior_num_threads) + if thread_transition: + prior_num_threads = num_threads + prior_env = "" + row_group += 1 + rows_by_group.append([]) + rows.append( + _Row( + results=row, + row_group=row_group, + render_env=(render_env and env != prior_env), + env_str_len=env_str_len, + row_name_str_len=row_name_str_len, + time_scale=self.time_scale, + colorize=self._colorize, + num_threads=num_threads if thread_transition else None, + ) + ) + rows_by_group[-1].append(row) + prior_env = env + + for i in range(len(self.column_keys)): + grouped_results = [tuple(row[i] for row in g) for g in rows_by_group] + column = _Column( + grouped_results=grouped_results, + time_scale=self.time_scale, + time_unit=self.time_unit, + trim_significant_figures=self._trim_significant_figures, + highlight_warnings=self._highlight_warnings,) + columns.append(column) + + rows_tuple, columns_tuple = tuple(rows), tuple(columns) + for ri in rows_tuple: + ri.register_columns(columns_tuple) + return rows_tuple, columns_tuple + + def render(self) -> str: + string_rows = [[""] + self.column_keys] + string_rows.extend(r.as_column_strings() for r in self.rows) + num_cols = max(len(i) for i in string_rows) + for sr in string_rows: + sr.extend(["" for _ in range(num_cols - len(sr))]) + + col_widths = [max(len(j) for j in i) for i in zip(*string_rows, strict=True)] + finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths, strict=True))] + overall_width = len(finalized_columns[0]) + for string_row, row in zip(string_rows[1:], self.rows, strict=True): + finalized_columns.extend(row.row_separator(overall_width)) + finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths))) + + newline = "\n" + has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results) + return f""" +[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}] +{newline.join(finalized_columns)} + +Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}). +{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:] + + +class Compare: + """Helper class for displaying the results of many measurements in a + formatted table. + + The table format is based on the information fields provided in + :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`, + `num_threads`, etc). + + The table can be directly printed using :meth:`print` or casted as a `str`. + + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + Args: + results: List of Measurement to display. + """ + def __init__(self, results: list[common.Measurement]) -> None: + self._results: list[common.Measurement] = [] + self.extend_results(results) + self._trim_significant_figures = False + self._colorize = Colorize.NONE + self._highlight_warnings = False + + def __str__(self) -> str: + return "\n".join(self._render()) + + def extend_results(self, results) -> None: + """Append results to already stored ones. + + All added results must be instances of ``Measurement``. + """ + for r in results: + if not isinstance(r, common.Measurement): + raise ValueError( + "Expected an instance of `Measurement`, " f"got {type(r)} instead." + ) + self._results.extend(results) + + def trim_significant_figures(self) -> None: + """Enables trimming of significant figures when building the formatted table.""" + self._trim_significant_figures = True + + def colorize(self, rowwise=False) -> None: + """Colorize formatted table. + + Colorize columnwise by default. + """ + self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE + + def highlight_warnings(self) -> None: + """Enables warning highlighting when building formatted table.""" + self._highlight_warnings = True + + def print(self) -> None: + """Print formatted table""" + print(str(self)) + + def _render(self): + results = common.Measurement.merge(self._results) + grouped_results = self._group_by_label(results) + output = [self._layout(group) for group in grouped_results.values()] + return output + + def _group_by_label(self, results: list[common.Measurement]): + grouped_results: collections.defaultdict[str, list[common.Measurement]] = collections.defaultdict(list) + for r in results: + grouped_results[r.label].append(r) + return grouped_results + + def _layout(self, results: list[common.Measurement]): + table = Table( + results, + self._colorize, + self._trim_significant_figures, + self._highlight_warnings + ) + return table.render() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..dd15a582a274980bea4aff22f7325ccf562ecb13 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py @@ -0,0 +1,195 @@ +# mypy: allow-untyped-defs +from typing import Any, cast +from collections.abc import Callable + +import torch +import torch._dynamo +from torch._dynamo.testing import CompileCounterWithBackend +from torch.utils.benchmark import Timer + + +__all__ = ["bench_all", "benchmark_compile"] + + +_warned_tensor_cores = False +_default_float_32_precision = torch.get_float32_matmul_precision() + +try: + + from tabulate import tabulate + + HAS_TABULATE = True +except ModuleNotFoundError: + HAS_TABULATE = False + tabulate = None # type: ignore[assignment] + print("tabulate is not installed, please pip install tabulate to use this utility") + +if HAS_TABULATE: + def _enable_tensor_cores() -> None: + global _warned_tensor_cores + + if torch.cuda.is_available(): + if torch.backends.cuda.matmul.allow_tf32 is False and torch.cuda.get_device_capability() >= (8, 0): + torch.set_float32_matmul_precision("high") + if not _warned_tensor_cores: + print("Your GPU supports tensor cores") + print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`") + _warned_tensor_cores = True + + def _disable_tensor_cores() -> None: + torch.set_float32_matmul_precision(_default_float_32_precision) + + def bench_loop( + model: torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, + num_iters: int = 5, + optimizer: torch.optim.Optimizer | None = None, + loss_fn: Callable | None = None, + ): + # Define the statement and setup for the benchmark + if optimizer and loss_fn: + # Training mode + stmt = """ + output = model(sample_input) + loss = loss_fn(output) if loss_fn else output.sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + """ + else: + # Inference mode + stmt = "model(sample_input)" + + # Create the Timer object + timer = Timer( + stmt=stmt, + globals={"model": model, "sample_input": sample_input, "optimizer": optimizer, "loss_fn": loss_fn}, + ) + + + result = timer.timeit(number=num_iters) + + # Get the average time per iteration in milliseconds + avg_time = result.mean * 1000 + return round(avg_time, 2) + + def benchmark_compile( + model: torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, + num_iters: int = 5, + backend: str | None = None, + mode: str | None = "default", + optimizer: torch.optim.Optimizer | None = None, + loss_fn : torch.nn.Module | Callable | None = None, + ): + """ + Use this utility to benchmark torch.compile + """ + if backend: + try: + torch._dynamo.reset() + compile_counter_with_backend = CompileCounterWithBackend(backend) + opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode) + + # Compilation only happens after the first inference + compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn) + + running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn) + + if compile_counter_with_backend.frame_count == 0: + raise RuntimeError("No compilation occurred during benchmarking.") + + if compile_counter_with_backend.frame_count > 1: + raise RuntimeError("Recompilation occurred during benchmarking.") + + except Exception as e: + print(e) + print(f"Failed to compile {backend} with mode {mode}") + return None, None + else: + opt_model = model + compilation_time = None + running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn) + + compilation_time = round(compilation_time, 2) if compilation_time else None + running_time = round(running_time, 2) if running_time else None + + + return compilation_time, running_time + + + def bench_all( + model : torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, + num_iters : int = 5, + optimizer: torch.optim.Optimizer | None = None, + loss_fn : torch.nn.Module | Callable | None = None, + ): + """ + This is a simple utility that can be used to benchmark torch.compile + In particular it ensures that your GPU is setup to use tensor cores if it supports its + It also tries out all the main backends and prints a table of results so you can easily compare them all + Many of the backendds have their own optional dependencies so please pip install them separately + + You will get one table for inference and another for training + If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer + + The important warnings are + Your GPU supports tensor cores + we will enable it automatically by setting `torch.set_float32_matmul_precision('high')` + + If a compilation fails for any reason including the dependency not being included + then we will print Failed to compile {backend} with mode {mode} + """ + field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"] + table = [] + + + eager_time = None + torch._dynamo.reset() + _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, optimizer) + table.append( + [("Training" if optimizer else "Inference"), "Eager", "-", "-", f"{eager_time} ms"] + ) + + for backend in torch._dynamo.list_backends(): + + if backend == "inductor": + mode_options = cast(list[str | None], list(torch._inductor.list_mode_options().keys())) + [None] + for mode in mode_options: + if mode == "default": + continue + torch._dynamo.reset() + try: + if torch.cuda.is_available(): + _enable_tensor_cores() + compilation_time, running_time = benchmark_compile( + model, sample_input, num_iters, backend, mode, optimizer, loss_fn) + finally: + if torch.cuda.is_available(): + _disable_tensor_cores() + table.append([ + ("Training" if optimizer else "Inference"), + # pyrefly: ignore [redundant-condition] + backend if backend else "-", + mode if mode is not None else "-", + f"{compilation_time} ms " if compilation_time else "-", + f"{running_time} ms " if running_time else "-", + ]) + + else: + torch._dynamo.reset() + compilation_time, running_time = benchmark_compile( + model, sample_input, num_iters, backend, None, optimizer, loss_fn) + + if running_time is not None: + table.append([ + ("Training" if optimizer else "Inference"), + backend, "-", + f"{compilation_time} ms " or "-", + f"{running_time} ms ", + ]) + + + # pyrefly: ignore [not-callable] + return tabulate(table, headers=field_names, tablefmt="github") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..a298146ce17c7ff6f303b4d76c4c96ba786ae774 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py @@ -0,0 +1,175 @@ +"""JIT C++ strings into executables.""" +import atexit +import os +import re +import shutil +import textwrap +import threading +from typing import Any + +import torch +from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType +from torch.utils.benchmark.utils.common import _make_temp_dir +from torch.utils import cpp_extension + + +LOCK = threading.Lock() +SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0] + +# We calculate uuid once at import time so that separate processes will have +# separate build roots, but threads will share the same build root. +# `cpp_extension` uses build root as part of the cache key, so per-invocation +# uuid's (e.g. different build root per _compile_template call) would lead to +# a 0% cache hit rate and spurious recompilation. Consider the following: +# ``` +# setup = "auto x = torch::ones({1024, 1024});" +# stmt = "torch::mm(x, x);" +# for num_threads in [1, 2, 4, 8]: +# print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange()) +# ```` +# `setup` and `stmt` do not change, so we can reuse the executable from the +# first pass through the loop. +_BUILD_ROOT: str | None = None + +def _get_build_root() -> str: + global _BUILD_ROOT + if _BUILD_ROOT is None: + _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build") + # pyrefly: ignore [missing-argument] + atexit.register(shutil.rmtree, _BUILD_ROOT) + return _BUILD_ROOT + + +# BACK_TESTING_NOTE: +# There are two workflows where this code could be used. One is the obvious +# case where someone simply builds or installs PyTorch and uses Timer. +# The other is that the entire `torch/utils/benchmark` folder from a CURRENT +# PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch +# source code. This is what we refer to here as "back testing". The rationale +# is that we might want to use current tooling to study some aspect of an +# earlier version of PyTorch. (e.g. a regression.) +# +# The problem is that Timer relies on several aspects of core PyTorch, namely +# some binding functions for Valgrind symbols in `torch._C` and the +# `torch.__config__._cxx_flags()` method. If we were to naively copy code +# around this wouldn't work as the symbols of interest aren't present in +# earlier versions of PyTorch. In order to work around this, we must add back +# testing shims. These shims will never activate during normal use, but will +# allow Timer to function outside of the "correct" version of PyTorch by +# emulating functionality that was added later. +# +# These shims are temporary, and as Timer becomes more integrated with +# PyTorch the cost and complexity of such shims will increase. Once back +# testing is no longer required (which is to say we have done enough historic +# analysis and the shims no longer justify their maintenance and code +# complexity costs) back testing paths will be removed. + +CXX_FLAGS: list[str] | None +if hasattr(torch.__config__, "_cxx_flags"): + try: + CXX_FLAGS = torch.__config__._cxx_flags().strip().split() + if CXX_FLAGS is not None and "-g" not in CXX_FLAGS: + CXX_FLAGS.append("-g") + # remove "-W" flags to allow build benchmarks + # with a relaxed constraint of compiler versions + if CXX_FLAGS is not None: + CXX_FLAGS = list(filter(lambda x: not x.startswith("-W"), CXX_FLAGS)) + + except RuntimeError: + # We are in FBCode. + CXX_FLAGS = None +else: + # FIXME: Remove when back testing is no longer required. + CXX_FLAGS = ["-O2", "-fPIC", "-g"] + +EXTRA_INCLUDE_PATHS: list[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")] +CONDA_PREFIX = os.getenv("CONDA_PREFIX") +if CONDA_PREFIX is not None: + # Load will automatically search /usr/include, but not conda include. + EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include")) + + +COMPAT_CALLGRIND_BINDINGS: CallgrindModuleType | None = None +def get_compat_bindings() -> CallgrindModuleType: + with LOCK: + global COMPAT_CALLGRIND_BINDINGS + if COMPAT_CALLGRIND_BINDINGS is None: + COMPAT_CALLGRIND_BINDINGS = cpp_extension.load( + name="callgrind_bindings", + sources=[os.path.join( + SOURCE_ROOT, + "valgrind_wrapper", + "compat_bindings.cpp" + )], + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + ) + return COMPAT_CALLGRIND_BINDINGS + + +def _compile_template( + *, + stmt: str, + setup: str, + global_setup: str, + src: str, + is_standalone: bool +) -> Any: + for before, after, indentation in ( + ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0), + ("// SETUP_TEMPLATE_LOCATION", setup, 4), + ("// STMT_TEMPLATE_LOCATION", stmt, 8) + ): + # C++ doesn't care about indentation so this code isn't load + # bearing the way it is with Python, but this makes the source + # look nicer if a human has to look at it. + src = re.sub( + before, + textwrap.indent(after, " " * indentation)[indentation:], + src + ) + + # We want to isolate different Timers. However `cpp_extension` will + # cache builds which will significantly reduce the cost of repeated + # invocations. + with LOCK: + name = f"timer_cpp_{abs(hash(src))}" + build_dir = os.path.join(_get_build_root(), name) + os.makedirs(build_dir, exist_ok=True) + + src_path = os.path.join(build_dir, "timer_src.cpp") + with open(src_path, "w") as f: + f.write(src) + + # `cpp_extension` has its own locking scheme, so we don't need our lock. + return cpp_extension.load( + name=name, + sources=[src_path], + build_directory=build_dir, + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + is_python_module=not is_standalone, + is_standalone=is_standalone, + ) + + +def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType: + template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp") + with open(template_path) as f: + src: str = f.read() + + module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False) + if not isinstance(module, TimeitModuleType): + raise AssertionError("compiled module is not a TimeitModuleType") + return module + + +def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str: + template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp") + with open(template_path) as f: + src: str = f.read() + + target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True) + if not isinstance(target, str): + raise AssertionError("compiled target path is not a string") + return target diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..38f771d23632efd27239e460591d923be3ee59fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +import functools +import itertools as it +from typing import Any +from collections.abc import Callable + +import torch + + +__all__ = [ + "Fuzzer", + "FuzzedParameter", "ParameterAlias", + "FuzzedTensor", +] + + +_DISTRIBUTIONS = ( + "loguniform", + "uniform", +) + + +class FuzzedParameter: + """Specification for a parameter to be generated during fuzzing.""" + def __init__( + self, + name: str, + minval: int | float | None = None, + maxval: int | float | None = None, + distribution: str | dict[Any, float] | None = None, + strict: bool = False, + ) -> None: + """ + Args: + name: + A string name with which to identify the parameter. + FuzzedTensors can reference this string in their + specifications. + minval: + The lower bound for the generated value. See the description + of `distribution` for type behavior. + maxval: + The upper bound for the generated value. Type behavior is + identical to `minval`. + distribution: + Specifies the distribution from which this parameter should + be drawn. There are three possibilities: + - "loguniform" + Samples between `minval` and `maxval` (inclusive) such + that the probabilities are uniform in log space. As a + concrete example, if minval=1 and maxval=100, a sample + is as likely to fall in [1, 10) as it is [10, 100]. + - "uniform" + Samples are chosen with uniform probability between + `minval` and `maxval` (inclusive). If either `minval` + or `maxval` is a float then the distribution is the + continuous uniform distribution; otherwise samples + are constrained to the integers. + - dict: + If a dict is passed, the keys are taken to be choices + for the variables and the values are interpreted as + probabilities. (And must sum to one.) + If a dict is passed, `minval` and `maxval` must not be set. + Otherwise, they must be set. + strict: + If a parameter is strict, it will not be included in the + iterative resampling process which Fuzzer uses to find a + valid parameter configuration. This allows an author to + prevent skew from resampling for a given parameter (for + instance, a low size limit could inadvertently bias towards + Tensors with fewer dimensions) at the cost of more iterations + when generating parameters. + """ + self._name = name + self._minval = minval + self._maxval = maxval + self._distribution = self._check_distribution(distribution) + self.strict = strict + + @property + def name(self): + return self._name + + def sample(self, state): + if self._distribution == "loguniform": + return self._loguniform(state) + + if self._distribution == "uniform": + return self._uniform(state) + + if isinstance(self._distribution, dict): + return self._custom_distribution(state) + + def _check_distribution(self, distribution): + if not isinstance(distribution, dict): + if distribution not in _DISTRIBUTIONS: + raise AssertionError(f"Unknown distribution: {distribution}") + else: + if any(i < 0 for i in distribution.values()): + raise AssertionError("Probabilities cannot be negative") + if not abs(sum(distribution.values()) - 1) > 1e-5: + raise AssertionError("Distribution is not normalized") + if self._minval is not None: + raise AssertionError("When passing a custom distribution, 'minval' must be None") + if self._maxval is not None: + raise AssertionError("When passing a custom distribution, 'maxval' must be None") + + return distribution + + def _loguniform(self, state): + import numpy as np + output = int(2 ** state.uniform( + low=np.log2(self._minval) if self._minval is not None else None, + high=np.log2(self._maxval) if self._maxval is not None else None, + )) + if self._minval is not None and output < self._minval: + return self._minval + if self._maxval is not None and output > self._maxval: + return self._maxval + return output + + def _uniform(self, state): + if isinstance(self._minval, int) and isinstance(self._maxval, int): + return int(state.randint(low=self._minval, high=self._maxval + 1)) + return state.uniform(low=self._minval, high=self._maxval) + + def _custom_distribution(self, state): + import numpy as np + # If we directly pass the keys to `choice`, numpy will convert + # them to numpy dtypes. + index = state.choice( + np.arange(len(self._distribution)), + p=tuple(self._distribution.values())) + return list(self._distribution.keys())[index] + + +class ParameterAlias: + """Indicates that a parameter should alias the value of another parameter. + + When used in conjunction with a custom distribution, this allows fuzzed + tensors to represent a broader range of behaviors. For example, the + following sometimes produces Tensors which broadcast: + + Fuzzer( + parameters=[ + FuzzedParameter("x_len", 4, 1024, distribution="uniform"), + + # `y` will either be size one, or match the size of `x`. + FuzzedParameter("y_len", distribution={ + 0.5: 1, + 0.5: ParameterAlias("x_len") + }), + ], + tensors=[ + FuzzedTensor("x", size=("x_len",)), + FuzzedTensor("y", size=("y_len",)), + ], + ) + + Chains of alias' are allowed, but may not contain cycles. + """ + def __init__(self, alias_to) -> None: + self.alias_to = alias_to + + def __repr__(self) -> str: + return f"ParameterAlias[alias_to: {self.alias_to}]" + + +def dtype_size(dtype): + if dtype == torch.bool: + return 1 + if dtype.is_floating_point or dtype.is_complex: + return int(torch.finfo(dtype).bits / 8) + return int(torch.iinfo(dtype).bits / 8) + + +def prod(values, base=1): + """np.prod can overflow, so for sizes the product should be done in Python. + + Even though np.prod type promotes to int64, it can still overflow in which + case the negative value will pass the size check and OOM when attempting to + actually allocate the Tensor. + """ + return functools.reduce(lambda x, y: int(x) * int(y), values, base) + + +class FuzzedTensor: + def __init__( + self, + name: str, + size: tuple[str | int, ...], + steps: tuple[str | int, ...] | None = None, + probability_contiguous: float = 0.5, + min_elements: int | None = None, + max_elements: int | None = None, + max_allocation_bytes: int | None = None, + dim_parameter: str | None = None, + roll_parameter: str | None = None, + dtype=torch.float32, + cuda=False, + tensor_constructor: Callable | None = None + ) -> None: + """ + Args: + name: + A string identifier for the generated Tensor. + size: + A tuple of integers or strings specifying the size of the generated + Tensor. String values will replaced with a concrete int during the + generation process, while ints are simply passed as literals. + steps: + An optional tuple with the same length as `size`. This indicates + that a larger Tensor should be allocated, and then sliced to + produce the generated Tensor. For instance, if size is (4, 8) + and steps is (1, 4), then a tensor `t` of size (4, 32) will be + created and then `t[:, ::4]` will be used. (Allowing one to test + Tensors with strided memory.) + probability_contiguous: + A number between zero and one representing the chance that the + generated Tensor has a contiguous memory layout. This is achieved by + randomly permuting the shape of a Tensor, calling `.contiguous()`, + and then permuting back. This is applied before `steps`, which can + also cause a Tensor to be non-contiguous. + min_elements: + The minimum number of parameters that this Tensor must have for a + set of parameters to be valid. (Otherwise they are resampled.) + max_elements: + Like `min_elements`, but setting an upper bound. + max_allocation_bytes: + Like `max_elements`, but for the size of Tensor that must be + allocated prior to slicing for `steps` (if applicable). For + example, a FloatTensor with size (1024, 1024) and steps (4, 4) + would have 1M elements, but would require a 64 MB allocation. + dim_parameter: + The length of `size` and `steps` will be truncated to this value. + This allows Tensors of varying dimensions to be generated by the + Fuzzer. + dtype: + The PyTorch dtype of the generated Tensor. + cuda: + Whether to place the Tensor on a GPU. + tensor_constructor: + Callable which will be used instead of the default Tensor + construction method. This allows the author to enforce properties + of the Tensor (e.g. it can only have certain values). The dtype and + concrete shape of the Tensor to be created will be passed, and + concrete values of all parameters will be passed as kwargs. Note + that transformations to the result (permuting, slicing) will be + performed by the Fuzzer; the tensor_constructor is only responsible + for creating an appropriately sized Tensor. + """ + self._name = name + self._size = size + self._steps = steps + self._probability_contiguous = probability_contiguous + self._min_elements = min_elements + self._max_elements = max_elements + self._max_allocation_bytes = max_allocation_bytes + self._dim_parameter = dim_parameter + self._dtype = dtype + self._cuda = cuda + self._tensor_constructor = tensor_constructor + + @property + def name(self): + return self._name + + @staticmethod + def default_tensor_constructor(size, dtype, **kwargs): + if dtype.is_floating_point or dtype.is_complex: + return torch.rand(size=size, dtype=dtype, device="cpu") + else: + return torch.randint(1, 127, size=size, dtype=dtype, device="cpu") + + def _make_tensor(self, params, state): + import numpy as np + size, steps, allocation_size = self._get_size_and_steps(params) + constructor = ( + self._tensor_constructor or + self.default_tensor_constructor + ) + + raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params) + if self._cuda: + raw_tensor = raw_tensor.cuda() + + # Randomly permute the Tensor and call `.contiguous()` to force re-ordering + # of the memory, and then permute it back to the original shape. + dim = len(size) + order = np.arange(dim) + if state.rand() > self._probability_contiguous: + while dim > 1 and np.all(order == np.arange(dim)): + order = state.permutation(raw_tensor.dim()) + + raw_tensor = raw_tensor.permute(tuple(order)).contiguous() + raw_tensor = raw_tensor.permute(tuple(np.argsort(order))) + + slices = [slice(0, size * step, step) for size, step in zip(size, steps, strict=True)] + tensor = raw_tensor[tuple(slices)] + + properties = { + "numel": int(tensor.numel()), + "order": order, + "steps": steps, + "is_contiguous": tensor.is_contiguous(), + "dtype": str(self._dtype), + } + + return tensor, properties + + def _get_size_and_steps(self, params): + dim = ( + params[self._dim_parameter] + if self._dim_parameter is not None + else len(self._size) + ) + + def resolve(values, dim): + """Resolve values into concrete integers.""" + values = tuple(params.get(i, i) for i in values) + if len(values) > dim: + values = values[:dim] + if len(values) < dim: + values = values + tuple(1 for _ in range(dim - len(values))) + return values + + size = resolve(self._size, dim) + steps = resolve(self._steps or (), dim) + allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True)) + return size, steps, allocation_size + + def satisfies_constraints(self, params) -> bool: + size, _, allocation_size = self._get_size_and_steps(params) + # Product is computed in Python to avoid integer overflow. + num_elements = prod(size) + if num_elements < 0: + raise AssertionError("Computed number of elements is negative") + + allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype)) + + def nullable_greater(left, right): + if left is None or right is None: + return False + return left > right + + return not any(( + nullable_greater(num_elements, self._max_elements), + nullable_greater(self._min_elements, num_elements), + nullable_greater(allocation_bytes, self._max_allocation_bytes), + )) + + +class Fuzzer: + def __init__( + self, + parameters: list[FuzzedParameter | list[FuzzedParameter]], + tensors: list[FuzzedTensor | list[FuzzedTensor]], + constraints: list[Callable] | None = None, + seed: int | None = None + ) -> None: + """ + Args: + parameters: + List of FuzzedParameters which provide specifications + for generated parameters. Iterable elements will be + unpacked, though arbitrary nested structures will not. + tensors: + List of FuzzedTensors which define the Tensors which + will be created each step based on the parameters for + that step. Iterable elements will be unpacked, though + arbitrary nested structures will not. + constraints: + List of callables. They will be called with params + as kwargs, and if any of them return False the current + set of parameters will be rejected. + seed: + Seed for the RandomState used by the Fuzzer. This will + also be used to set the PyTorch random seed so that random + ops will create reproducible Tensors. + """ + import numpy as np + if seed is None: + seed = int(np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64)) + self._seed = seed + self._parameters = Fuzzer._unpack(parameters, FuzzedParameter) + self._tensors = Fuzzer._unpack(tensors, FuzzedTensor) + self._constraints = constraints or () + + p_names = {p.name for p in self._parameters} + t_names = {t.name for t in self._tensors} + name_overlap = p_names.intersection(t_names) + if name_overlap: + raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}") + + self._rejections = 0 + self._total_generated = 0 + + @staticmethod + def _unpack(values, cls): + return tuple(it.chain.from_iterable( + [[i] if isinstance(i, cls) else i for i in values] + )) + + def take(self, n): + import numpy as np + state = np.random.RandomState(self._seed) + torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64)) + for _ in range(n): + params = self._generate(state) + tensors = {} + tensor_properties = {} + for t in self._tensors: + tensor, properties = t._make_tensor(params, state) + tensors[t.name] = tensor + tensor_properties[t.name] = properties + yield tensors, tensor_properties, params + + @property + def rejection_rate(self): + if not self._total_generated: + return 0. + return self._rejections / self._total_generated + + def _generate(self, state): + strict_params: dict[str, float | int | ParameterAlias] = {} + for _ in range(1000): + candidate_params: dict[str, float | int | ParameterAlias] = {} + for p in self._parameters: + if p.strict: + if p.name in strict_params: + candidate_params[p.name] = strict_params[p.name] + else: + candidate_params[p.name] = p.sample(state) + strict_params[p.name] = candidate_params[p.name] + else: + candidate_params[p.name] = p.sample(state) + + candidate_params = self._resolve_aliases(candidate_params) + + self._total_generated += 1 + if not all(f(candidate_params) for f in self._constraints): + self._rejections += 1 + continue + + if not all(t.satisfies_constraints(candidate_params) for t in self._tensors): + self._rejections += 1 + continue + + return candidate_params + raise ValueError("Failed to generate a set of valid parameters.") + + @staticmethod + def _resolve_aliases(params): + params = dict(params) + alias_count = sum(isinstance(v, ParameterAlias) for v in params.values()) + + keys = list(params.keys()) + while alias_count: + for k in keys: + v = params[k] + if isinstance(v, ParameterAlias): + params[k] = params[v.alias_to] + alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values()) + if alias_count == alias_count_new: + raise ValueError(f"ParameterAlias cycle detected\n{params}") + + alias_count = alias_count_new + + return params diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a573b9b44fdc2ee3c5141d8badc75a2b484a78 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-defs +from numbers import Number +import torch +from torch.utils.benchmark import FuzzedTensor +import math + +class FuzzedSparseTensor(FuzzedTensor): + def __init__( + self, + name: str, + size: tuple[str | int, ...], + min_elements: int | None = None, + max_elements: int | None = None, + dim_parameter: str | None = None, + sparse_dim: str | None = None, + nnz: str | None = None, + density: str | None = None, + coalesced: str | None = None, + dtype=torch.float32, + cuda=False + ) -> None: + """ + Args: + name: + A string identifier for the generated Tensor. + size: + A tuple of integers or strings specifying the size of the generated + Tensor. String values will replaced with a concrete int during the + generation process, while ints are simply passed as literals. + min_elements: + The minimum number of parameters that this Tensor must have for a + set of parameters to be valid. (Otherwise they are resampled.) + max_elements: + Like `min_elements`, but setting an upper bound. + dim_parameter: + The length of `size` will be truncated to this value. + This allows Tensors of varying dimensions to be generated by the + Fuzzer. + sparse_dim: + The number of sparse dimensions in a sparse tensor. + density: + This value allows tensors of varying sparsities to be generated by the Fuzzer. + coalesced: + The sparse tensor format permits uncoalesced sparse tensors, + where there may be duplicate coordinates in the indices. + dtype: + The PyTorch dtype of the generated Tensor. + cuda: + Whether to place the Tensor on a GPU. + """ + super().__init__(name=name, size=size, min_elements=min_elements, + max_elements=max_elements, dim_parameter=dim_parameter, dtype=dtype, cuda=cuda) + self._density = density + self._coalesced = coalesced + self._sparse_dim = sparse_dim + + @staticmethod + def sparse_tensor_constructor(size, dtype, sparse_dim, nnz, is_coalesced): + """sparse_tensor_constructor creates a sparse tensor with coo format. + + Note that when `is_coalesced` is False, the number of elements is doubled but the number of indices + represents the same amount of number of non zeros `nnz`, i.e, this is virtually the same tensor + with the same sparsity pattern. Moreover, most of the sparse operation will use coalesce() method + and what we want here is to get a sparse tensor with the same `nnz` even if this is coalesced or not. + + In the other hand when `is_coalesced` is True the number of elements is reduced in the coalescing process + by an unclear amount however the probability to generate duplicates indices are low for most of the cases. + This decision was taken on purpose to maintain the construction cost as low as possible. + """ + if isinstance(size, Number): + size = [size] * sparse_dim + if all(size[d] <= 0 for d in range(sparse_dim)) and nnz != 0: + raise AssertionError('invalid arguments') + v_size = [nnz] + list(size[sparse_dim:]) + if dtype.is_floating_point: + v = torch.rand(size=v_size, dtype=dtype, device="cpu") + else: + v = torch.randint(1, 127, size=v_size, dtype=dtype, device="cpu") + + i = torch.rand(sparse_dim, nnz, device="cpu") + i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) + i = i.to(torch.long) + + if not is_coalesced: + v = torch.cat([v, torch.randn_like(v)], 0) + i = torch.cat([i, i], 1) + + x = torch.sparse_coo_tensor(i, v, torch.Size(size)) + if is_coalesced: + x = x.coalesce() + return x + + def _make_tensor(self, params, state): + # pyrefly: ignore [missing-attribute] + size, _, _ = self._get_size_and_steps(params) + density = params['density'] + nnz = math.ceil(sum(size) * density) + if nnz > sum(size): + raise AssertionError('nnz cannot exceed total number of elements') + + is_coalesced = params['coalesced'] + sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size) + sparse_dim = min(sparse_dim, len(size)) + # pyrefly: ignore [missing-attribute] + tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced) + + # pyrefly: ignore [missing-attribute] + if self._cuda: + tensor = tensor.cuda() + sparse_dim = tensor.sparse_dim() + dense_dim = tensor.dense_dim() + is_hybrid = len(size[sparse_dim:]) > 0 + + properties = { + "numel": int(tensor.numel()), + "shape": tensor.size(), + "is_coalesced": tensor.is_coalesced(), + "density": density, + "sparsity": 1.0 - density, + "sparse_dim": sparse_dim, + "dense_dim": dense_dim, + "is_hybrid": is_hybrid, + # pyrefly: ignore [missing-attribute] + "dtype": str(self._dtype), + } + return tensor, properties diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30b6f79c0b5aebca676035ac0b7c08cfce639b23 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp @@ -0,0 +1,43 @@ +/* C++ template for Timer.timeit + +This template will be consumed by `cpp_jit.py`, and will replace: + `GLOBAL_SETUP_TEMPLATE_LOCATION`, + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ +#include + +#include +#include +#include +#include + +// Global setup. (e.g. #includes) +// GLOBAL_SETUP_TEMPLATE_LOCATION + +double timeit(int n) { + pybind11::gil_scoped_release no_gil; + + // Setup + // SETUP_TEMPLATE_LOCATION + + { + // Warmup + // STMT_TEMPLATE_LOCATION + } + + // Main loop + auto start_time = std::chrono::high_resolution_clock::now(); + for (const auto loop_idx : c10::irange(n)) { + (void)loop_idx; + // STMT_TEMPLATE_LOCATION + } + auto end_time = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end_time - start_time).count(); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("timeit", &timeit); +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..f131261b8f36d08e4d9ef87605f379c4215d63ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py @@ -0,0 +1,533 @@ +"""Timer class based on the timeit.Timer class, but torch aware.""" +import enum +import timeit +import textwrap +from typing import overload, Any, NoReturn +from collections.abc import Callable + +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType +from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface + + +__all__ = ["Timer", "timer", "Language"] + + +if torch.accelerator.is_available(): + def timer() -> float: + torch.accelerator.synchronize() + return timeit.default_timer() +else: + timer = timeit.default_timer + + +class Language(enum.Enum): + PYTHON = 0 + CPP = 1 + + +class CPPTimer: + def __init__( + self, + stmt: str, + setup: str, + global_setup: str, + timer: Callable[[], float], + globals: dict[str, Any], + ) -> None: + if timer is not timeit.default_timer: + raise NotImplementedError( + "PyTorch was built with accelerators and an accelerator is present; however " + "Timer does not yet support accelerator measurements. If your " + "code is CPU only, pass `timer=timeit.default_timer` to the " + "Timer's constructor to indicate this. (Note that this will " + "produce incorrect results if an accelerator is in fact used, as " + "Timer will not synchronize the accelerator.)" + ) + + if globals: + raise ValueError("C++ timing does not support globals.") + + self._stmt: str = textwrap.dedent(stmt) + self._setup: str = textwrap.dedent(setup) + self._global_setup: str = textwrap.dedent(global_setup) + self._timeit_module: TimeitModuleType | None = None + + def timeit(self, number: int) -> float: + if self._timeit_module is None: + self._timeit_module = cpp_jit.compile_timeit_template( + stmt=self._stmt, + setup=self._setup, + global_setup=self._global_setup, + ) + + return self._timeit_module.timeit(number) + + +class Timer: + """Helper class for measuring execution time of PyTorch statements. + + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + The PyTorch Timer is based on `timeit.Timer` (and in fact uses + `timeit.Timer` internally), but with several key differences: + + 1) Runtime aware: + Timer will perform warmups (important as some elements of PyTorch are + lazily initialized), set threadpool size so that comparisons are + apples-to-apples, and synchronize asynchronous accelerator functions when + necessary. + + 2) Focus on replicates: + When measuring code, and particularly complex kernels / models, + run-to-run variation is a significant confounding factor. It is + expected that all measurements should include replicates to quantify + noise and allow median computation, which is more robust than mean. + To that effect, this class deviates from the `timeit` API by + conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`. + (Exact algorithms are discussed in method docstrings.) The `timeit` + method is replicated for cases where an adaptive strategy is not + desired. + + 3) Optional metadata: + When defining a Timer, one can optionally specify `label`, `sub_label`, + `description`, and `env`. (Defined later) These fields are included in + the representation of result object and by the `Compare` class to group + and display results for comparison. + + 4) Instruction counts + In addition to wall times, Timer can run a statement under Callgrind + and report instructions executed. + + Directly analogous to `timeit.Timer` constructor arguments: + + `stmt`, `setup`, `timer`, `globals` + + PyTorch Timer specific constructor arguments: + + `label`, `sub_label`, `description`, `env`, `num_threads` + + Args: + stmt: Code snippet to be run in a loop and timed. + + setup: Optional setup code. Used to define variables used in `stmt` + + global_setup: (C++ only) + Code which is placed at the top level of the file for things like + `#include` statements. + + timer: + Callable which returns the current time. If PyTorch was built + without accelerators or there is no accelerator present, this defaults to + `timeit.default_timer`; otherwise it will synchronize accelerators before + measuring the time. + + globals: + A dict which defines the global variables when `stmt` is being + executed. This is the other method for providing variables which + `stmt` needs. + + label: + String which summarizes `stmt`. For instance, if `stmt` is + "torch.nn.functional.relu(torch.add(x, 1, out=out))" + one might set label to "ReLU(x + 1)" to improve readability. + + sub_label: + Provide supplemental information to disambiguate measurements + with identical stmt or label. For instance, in our example + above sub_label might be "float" or "int", so that it is easy + to differentiate: + "ReLU(x + 1): (float)" + + "ReLU(x + 1): (int)" + when printing Measurements or summarizing using `Compare`. + + description: + String to distinguish measurements with identical label and + sub_label. The principal use of `description` is to signal to + `Compare` the columns of data. For instance one might set it + based on the input size to create a table of the form: :: + + | n=1 | n=4 | ... + ------------- ... + ReLU(x + 1): (float) | ... | ... | ... + ReLU(x + 1): (int) | ... | ... | ... + + + using `Compare`. It is also included when printing a Measurement. + + env: + This tag indicates that otherwise identical tasks were run in + different environments, and are therefore not equivalent, for + instance when A/B testing a change to a kernel. `Compare` will + treat Measurements with different `env` specification as distinct + when merging replicate runs. + + num_threads: + The size of the PyTorch threadpool when executing `stmt`. Single + threaded performance is important as both a key inference workload + and a good indicator of intrinsic algorithmic efficiency, so the + default is set to one. This is in contrast to the default PyTorch + threadpool size which tries to utilize all cores. + """ + + _timer_cls: type[TimerClass] = timeit.Timer + + def __init__( + self, + stmt: str = "pass", + setup: str = "pass", + global_setup: str = "", + timer: Callable[[], float] = timer, + globals: dict[str, Any] | None = None, + label: str | None = None, + sub_label: str | None = None, + description: str | None = None, + env: str | None = None, + num_threads: int = 1, + language: Language | str = Language.PYTHON, + ) -> None: + if not isinstance(stmt, str): + raise ValueError("Currently only a `str` stmt is supported.") + + # We copy `globals` to prevent mutations from leaking. + # (For instance, `eval` adds the `__builtins__` key) + self._globals = dict(globals or {}) + + timer_kwargs = {} + if language in (Language.PYTHON, "py", "python"): + # Include `torch` if not specified as a convenience feature. + self._globals.setdefault("torch", torch) + self._language: Language = Language.PYTHON + if global_setup: + raise ValueError( + f"global_setup is C++ only, got `{global_setup}`. Most " + "likely this code can simply be moved to `setup`." + ) + + elif language in (Language.CPP, "cpp", "c++"): + if self._timer_cls is not timeit.Timer: + raise AssertionError("_timer_cls has already been swapped.") + self._timer_cls = CPPTimer + setup = ("" if setup == "pass" else setup) + self._language = Language.CPP + timer_kwargs["global_setup"] = global_setup + + else: + raise ValueError(f"Invalid language `{language}`.") + + # Convenience adjustment so that multi-line code snippets defined in + # functions do not IndentationError (Python) or look odd (C++). The + # leading newline removal is for the initial newline that appears when + # defining block strings. For instance: + # textwrap.dedent(""" + # print("This is a stmt") + # """) + # produces '\nprint("This is a stmt")\n'. + # + # Stripping this down to 'print("This is a stmt")' doesn't change + # what gets executed, but it makes __repr__'s nicer. + stmt = textwrap.dedent(stmt) + stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip() + setup = textwrap.dedent(setup) + setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() + + # pyrefly: ignore [bad-instantiation] + self._timer = self._timer_cls( + stmt=stmt, + setup=setup, + timer=timer, + globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals), + **timer_kwargs, + ) + self._task_spec = common.TaskSpec( + stmt=stmt, + setup=setup, + global_setup=global_setup, + label=label, + sub_label=sub_label, + description=description, + env=env, + num_threads=num_threads, + ) + + def _timeit(self, number: int) -> float: + # Even calling a timer in C++ takes ~50 ns, so no real operation should + # take less than 1 ns. (And this prevents divide by zero errors.) + return max(self._timer.timeit(number), 1e-9) + + def timeit(self, number: int = 1000000) -> common.Measurement: + """Mirrors the semantics of timeit.Timer.timeit(). + + Execute the main statement (`stmt`) `number` times. + https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit + """ + with common.set_torch_threads(self._task_spec.num_threads): + # Warmup + self._timeit(number=max(int(number // 100), 2)) + + return common.Measurement( + number_per_run=number, + raw_times=[self._timeit(number=number)], + task_spec=self._task_spec + ) + + def repeat(self, repeat: int = -1, number: int = -1) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def autorange(self, callback: Callable[[int, float], NoReturn] | None = None) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def _threaded_measurement_loop( + self, + number: int, + time_hook: Callable[[], float], + stop_hook: Callable[[list[float]], bool], + min_run_time: float, + max_run_time: float | None = None, + callback: Callable[[int, float], NoReturn] | None = None + ) -> list[float]: + total_time = 0.0 + can_stop = False + times: list[float] = [] + with common.set_torch_threads(self._task_spec.num_threads): + while (total_time < min_run_time) or (not can_stop): + time_spent = time_hook() + times.append(time_spent) + total_time += time_spent + if callback: + callback(number, time_spent) + can_stop = stop_hook(times) + if max_run_time and total_time > max_run_time: + break + return times + + def _estimate_block_size(self, min_run_time: float) -> int: + with common.set_torch_threads(self._task_spec.num_threads): + # Estimate the block size needed for measurement to be negligible + # compared to the inner loop. This also serves as a warmup. + overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item() + number = 1 + while True: + time_taken = self._timeit(number) + relative_overhead = overhead / time_taken + if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000: + break + if time_taken > min_run_time: + break + # Avoid overflow in C++ pybind11 interface + if number * 10 > 2147483647: + break + number *= 10 + return number + + def blocked_autorange( + self, + callback: Callable[[int, float], NoReturn] | None = None, + min_run_time: float = 0.2, + ) -> common.Measurement: + """Measure many replicates while keeping timer overhead to a minimum. + + At a high level, blocked_autorange executes the following pseudo-code:: + + `setup` + + total_time = 0 + while total_time < min_run_time + start = timer() + for _ in range(block_size): + `stmt` + total_time += (timer() - start) + + Note the variable `block_size` in the inner loop. The choice of block + size is important to measurement quality, and must balance two + competing objectives: + + 1) A small block size results in more replicates and generally + better statistics. + + 2) A large block size better amortizes the cost of `timer` + invocation, and results in a less biased measurement. This is + important because accelerator synchronization time is non-trivial + (order single to low double digit microseconds) and would + otherwise bias the measurement. + + blocked_autorange sets block_size by running a warmup period, + increasing block size until timer overhead is less than 0.1% of + the overall computation. This value is then used for the main + measurement loop. + + Returns: + A `Measurement` object that contains measured runtimes and + repetition counts, and can be used to compute statistics. + (mean, median, etc.) + """ + number = self._estimate_block_size(min_run_time) + + def time_hook() -> float: + return self._timeit(number) + + def stop_hook(times: list[float]) -> bool: + return True + + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, + min_run_time=min_run_time, + callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + def adaptive_autorange( + self, + threshold: float = 0.1, + *, + min_run_time: float = 0.01, + max_run_time: float = 10.0, + callback: Callable[[int, float], NoReturn] | None = None, + ) -> common.Measurement: + """Similar to `blocked_autorange` but also checks for variablility in measurements + and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached. + + + At a high level, adaptive_autorange executes the following pseudo-code:: + + `setup` + + times = [] + while times.sum < max_run_time + start = timer() + for _ in range(block_size): + `stmt` + times.append(timer() - start) + + enough_data = len(times)>3 and times.sum > min_run_time + small_iqr=times.iqr/times.mean float: + return self._timeit(number) + + def stop_hook(times: list[float]) -> bool: + if len(times) > 3: + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ).meets_confidence(threshold=threshold) + return False + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + @overload + def collect_callgrind( + self, + number: int, + *, + repeats: None, + collect_baseline: bool, + retain_out_file: bool, + ) -> valgrind_timer_interface.CallgrindStats: + ... + + @overload + def collect_callgrind( + self, + number: int, + *, + repeats: int, + collect_baseline: bool, + retain_out_file: bool, + ) -> tuple[valgrind_timer_interface.CallgrindStats, ...]: + ... + + def collect_callgrind( + self, + number: int = 100, + *, + repeats: int | None = None, + collect_baseline: bool = True, + retain_out_file: bool = False, + ) -> Any: + """Collect instruction counts using Callgrind. + + Unlike wall times, instruction counts are deterministic + (modulo non-determinism in the program itself and small amounts of + jitter from the Python interpreter.) This makes them ideal for detailed + performance analysis. This method runs `stmt` in a separate process + so that Valgrind can instrument the program. Performance is severely + degraded due to the instrumentation, however this is ameliorated by + the fact that a small number of iterations is generally sufficient to + obtain good measurements. + + In order to use this method `valgrind`, `callgrind_control`, and + `callgrind_annotate` must be installed. + + Because there is a process boundary between the caller (this process) + and the `stmt` execution, `globals` cannot contain arbitrary in-memory + data structures. (Unlike timing methods) Instead, globals are + restricted to builtins, `nn.Modules`'s, and TorchScripted functions/modules + to reduce the surprise factor from serialization and subsequent + deserialization. The `GlobalsBridge` class provides more detail on this + subject. Take particular care with nn.Modules: they rely on pickle and + you may need to add an import to `setup` for them to transfer properly. + + By default, a profile for an empty statement will be collected and + cached to indicate how many instructions are from the Python loop which + drives `stmt`. + + Returns: + A `CallgrindStats` object which provides instruction counts and + some basic facilities for analyzing and manipulating results. + """ + if not isinstance(self._task_spec.stmt, str): + raise ValueError("`collect_callgrind` currently only supports string `stmt`") + + if repeats is not None and repeats < 1: + raise ValueError("If specified, `repeats` must be >= 1") + + # Check that the statement is valid. It doesn't guarantee success, but it's much + # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in + # the parent process rather than the valgrind subprocess. + self._timeit(1) + is_python = (self._language == Language.PYTHON) + if not is_python and self._globals: + raise AssertionError("_timer globals are only supported for Python timers") + result = valgrind_timer_interface.wrapper_singleton().collect_callgrind( + task_spec=self._task_spec, + globals=self._globals, + number=number, + repeats=repeats or 1, + collect_baseline=collect_baseline and is_python, + is_python=is_python, + retain_out_file=retain_out_file, + ) + + return (result[0] if repeats is None else result) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e04973a065e0f50a966b167f59bd1ec45b7cadb9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b510099bd1ec22373f1650908aff3b85f99422a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h new file mode 100644 index 0000000000000000000000000000000000000000..f078cc82b95daf94d2bea51f1e1b1a8c12daea23 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h @@ -0,0 +1,129 @@ + +/* + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (callgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of callgrind, a valgrind tool for cache simulation + and call tree tracing. + + Copyright (C) 2003-2017 Josef Weidendorfer. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (callgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + +#ifndef __CALLGRIND_H +#define __CALLGRIND_H + +#include "valgrind.h" + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE ORDER OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end. + + The identification ('C','T') for Callgrind has historical + reasons: it was called "Calltree" before. Besides, ('C','G') would + clash with cachegrind. + */ + +typedef + enum { + VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'), + VG_USERREQ__ZERO_STATS, + VG_USERREQ__TOGGLE_COLLECT, + VG_USERREQ__DUMP_STATS_AT, + VG_USERREQ__START_INSTRUMENTATION, + VG_USERREQ__STOP_INSTRUMENTATION + } Vg_CallgrindClientRequest; + +/* Dump current state of cost centers, and zero them afterwards */ +#define CALLGRIND_DUMP_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS, \ + 0, 0, 0, 0, 0) + +/* Dump current state of cost centers, and zero them afterwards. + The argument is appended to a string stating the reason which triggered + the dump. This string is written as a description field into the + profile data dump. */ +#define CALLGRIND_DUMP_STATS_AT(pos_str) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT, \ + pos_str, 0, 0, 0, 0) + +/* Zero cost centers */ +#define CALLGRIND_ZERO_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS, \ + 0, 0, 0, 0, 0) + +/* Toggles collection state. + The collection state specifies whether the happening of events + should be noted or if they are to be ignored. Events are noted + by increment of counters in a cost center */ +#define CALLGRIND_TOGGLE_COLLECT \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT, \ + 0, 0, 0, 0, 0) + +/* Start full callgrind instrumentation if not already switched on. + When cache simulation is done, it will flush the simulated cache; + this will lead to an artificial cache warmup phase afterwards with + cache misses which would not have happened in reality. */ +#define CALLGRIND_START_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +/* Stop full callgrind instrumentation if not already switched off. + This flushes Valgrinds translation cache, and does no additional + instrumentation afterwards, which effectivly will run at the same + speed as the "none" tool (ie. at minimal slowdown). + Use this to bypass Callgrind aggregation for uninteresting code parts. + To start Callgrind in this mode to ignore the setup phase, use + the option "--instr-atstart=no". */ +#define CALLGRIND_STOP_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +#endif /* __CALLGRIND_H */ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd41f0de092f0b1488c8945edf2af80c6f9b596c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp @@ -0,0 +1,35 @@ +/* Used to collect profiles of old versions of PyTorch. */ +#include +#include + +bool _valgrind_supported_platform() { +#if defined(NVALGRIND) + return false; +#else + return true; +#endif +} + +void _valgrind_toggle() { +#if defined(NVALGRIND) + TORCH_CHECK(false, "Valgrind is not supported."); +#else + CALLGRIND_TOGGLE_COLLECT; +#endif +} + +void _valgrind_toggle_and_dump_stats() { +#if defined(NVALGRIND) + TORCH_CHECK(false, "Valgrind is not supported."); +#else + // NB: See note in Module.cpp + CALLGRIND_TOGGLE_COLLECT; + CALLGRIND_DUMP_STATS; +#endif +} + +PYBIND11_MODULE(callgrind_bindings, m) { + m.def("_valgrind_supported_platform", &_valgrind_supported_platform); + m.def("_valgrind_toggle", &_valgrind_toggle); + m.def("_valgrind_toggle_and_dump_stats", &_valgrind_dump_stats); +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp new file mode 100644 index 0000000000000000000000000000000000000000..587685c7df7445b299c35462307f47cf6012a00d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp @@ -0,0 +1,68 @@ +/* C++ template for Timer.collect_callgrind + +This template will be consumed by `cpp_jit.py`, and will replace: + `GLOBAL_SETUP_TEMPLATE_LOCATION`, + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ + +#include +#include +#include + +#include + +// Global setup. (e.g. #includes) +// GLOBAL_SETUP_TEMPLATE_LOCATION + +#if defined(NVALGRIND) +static_assert(false); +#endif + +int main(int argc, char* argv[]) { + // This file should only be called inside of `Timer`, so we can adopt a + // very simple and rigid argument parsing scheme. + TORCH_CHECK(argc == 9); + TORCH_CHECK(std::string(argv[1]) == "--number"); + auto number = std::stoi(argv[2]); + + TORCH_CHECK( + std::string(argv[3]) == "--number-warmup" || + std::string(argv[3]) == "--number_warmup"); + auto number_warmup = std::stoi(argv[4]); + + TORCH_CHECK(std::string(argv[5]) == "--repeats"); + auto repeats = std::stoi(argv[6]); + + TORCH_CHECK( + std::string(argv[7]) == "--number-threads" || + std::string(argv[7]) == "--number_threads"); + auto number_threads = std::stoi(argv[8]); + torch::set_num_threads(number_threads); + + // Setup + // SETUP_TEMPLATE_LOCATION + + // Warmup + for (const auto i : c10::irange(number_warmup)) { + (void)i; + // STMT_TEMPLATE_LOCATION + } + + // Main loop + for (const auto repeat : c10::irange(repeats)) { + (void)repeat; + CALLGRIND_TOGGLE_COLLECT; + + for (const auto i : c10::irange(number)) { + (void)i; + // STMT_TEMPLATE_LOCATION + } + + // NB: See note in Module.cpp + CALLGRIND_TOGGLE_COLLECT; + CALLGRIND_DUMP_STATS; + } +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..17ecea8bbb5598db967e8213b5bbd9c0fd8562f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -0,0 +1,919 @@ +"""Intermediate layer between `Timer` and `valgrind`.""" +import collections +import enum +import dataclasses +import itertools as it +import os +import pickle +import re +import shutil +import subprocess +import sys +import textwrap +from typing import ( + cast, Any, NamedTuple, + Union, TYPE_CHECKING) +from collections.abc import Callable +from collections.abc import Iterator + +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import CallgrindModuleType +import operator + + +__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"] + + +if TYPE_CHECKING: + CompletedProcessType = subprocess.CompletedProcess[str] +else: + CompletedProcessType = subprocess.CompletedProcess + + +class FunctionCount(NamedTuple): + # TODO(#105471): Rename the count field + count: int # type: ignore[assignment] + function: str + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class FunctionCounts: + """Container for manipulating Callgrind results. + + It supports: + 1) Addition and subtraction to combine or diff results. + 2) Tuple-like indexing. + 3) A `denoise` function which strips CPython calls which are known to + be non-deterministic and quite noisy. + 4) Two higher order methods (`filter` and `transform`) for custom + manipulation. + """ + _data: tuple[FunctionCount, ...] + inclusive: bool + truncate_rows: bool = True + + # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines + # the print settings. This is simply to allow hermetic unit tests. + _linewidth: int | None = None + + def __iter__(self) -> Iterator[FunctionCount]: + yield from self._data + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]: + data: FunctionCount | tuple[FunctionCount, ...] = self._data[item] + return ( + FunctionCounts(cast(tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False) + if isinstance(data, tuple) else data + ) + + def __repr__(self) -> str: + count_len = 0 + for c, _ in self: + # Account for sign in string length. + count_len = max(count_len, len(str(c)) + int(c < 0)) + + lines = [] + linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth + fn_str_len = max(linewidth - count_len - 4, 40) + for c, fn in self: + if len(fn) > fn_str_len: + left_len = int((fn_str_len - 5) // 2) + fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):] + lines.append(f" {c:>{count_len}} {fn}") + + if self.truncate_rows and len(lines) > 18: + lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:] + + if not self.inclusive: + lines.extend(["", f"Total: {self.sum()}"]) + + return "\n".join([super().__repr__()] + lines) + + def __add__( + self, + other: "FunctionCounts", + ) -> "FunctionCounts": + return self._merge(other, lambda c: c) + + def __sub__( + self, + other: "FunctionCounts", + ) -> "FunctionCounts": + return self._merge(other, operator.neg) + + def __mul__(self, other: int | float) -> "FunctionCounts": + return self._from_dict({ + fn: int(c * other) for c, fn in self._data + }, self.inclusive) + + def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts": + """Apply `map_fn` to all of the function names. + + This can be used to regularize function names (e.g. stripping irrelevant + parts of the file path), coalesce entries by mapping multiple functions + to the same name (in which case the counts are added together), etc. + """ + counts: collections.defaultdict[str, int] = collections.defaultdict(int) + for c, fn in self._data: + counts[map_fn(fn)] += c + + return self._from_dict(counts, self.inclusive) + + def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts": + """Keep only the elements where `filter_fn` applied to function name returns True.""" + return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive) + + def sum(self) -> int: + return sum(c for c, _ in self) + + def denoise(self) -> "FunctionCounts": + """Remove known noisy instructions. + + Several instructions in the CPython interpreter are rather noisy. These + instructions involve unicode to dictionary lookups which Python uses to + map variable names. FunctionCounts is generally a content agnostic + container, however this is sufficiently important for obtaining + reliable results to warrant an exception.""" + return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn) + + def _merge( + self, + second: "FunctionCounts", + merge_fn: Callable[[int], int] + ) -> "FunctionCounts": + if self.inclusive != second.inclusive: + raise AssertionError("Cannot merge inclusive and exclusive counts.") + counts: collections.defaultdict[str, int] = collections.defaultdict(int) + for c, fn in self: + counts[fn] += c + + for c, fn in second: + counts[fn] += merge_fn(c) + + return self._from_dict(counts, self.inclusive) + + @staticmethod + def _from_dict(counts: dict[str, int], inclusive: bool) -> "FunctionCounts": + flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c) + return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive) + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class CallgrindStats: + """Top level container for Callgrind results collected by Timer. + + Manipulation is generally done using the FunctionCounts class, which is + obtained by calling `CallgrindStats.stats(...)`. Several convenience + methods are provided as well; the most significant is + `CallgrindStats.as_standardized()`. + """ + task_spec: common.TaskSpec + number_per_run: int + built_with_debug_symbols: bool + baseline_inclusive_stats: FunctionCounts + baseline_exclusive_stats: FunctionCounts + stmt_inclusive_stats: FunctionCounts + stmt_exclusive_stats: FunctionCounts + stmt_callgrind_out: str | None + + def __repr__(self) -> str: + base_stats = self.baseline_exclusive_stats + output = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'':>25}All{'':>10}Noisy symbols removed + Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12} + Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12} +{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''} +""".strip() + if not self.built_with_debug_symbols: + output += textwrap.dedent(""" + Warning: PyTorch was not built with debug symbols. + Source information may be limited. Rebuild with + REL_WITH_DEB_INFO=1 for more detailed results.""") + return output + + def stats(self, inclusive: bool = False) -> FunctionCounts: + """Returns detailed function counts. + + Conceptually, the FunctionCounts returned can be thought of as a tuple + of (count, path_and_function_name) tuples. + + `inclusive` matches the semantics of callgrind. If True, the counts + include instructions executed by children. `inclusive=True` is useful + for identifying hot spots in code; `inclusive=False` is useful for + reducing noise when diffing counts from two different runs. (See + CallgrindStats.delta(...) for more details) + """ + return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats + + def counts(self, *, denoise: bool = False) -> int: + """Returns the total number of instructions executed. + + See `FunctionCounts.denoise()` for an explanation of the `denoise` arg. + """ + stats = self.stmt_exclusive_stats + return (stats.denoise() if denoise else stats).sum() + + # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563 + def delta( + self, + other: "CallgrindStats", + inclusive: bool = False, + ) -> FunctionCounts: + """Diff two sets of counts. + + One common reason to collect instruction counts is to determine the + the effect that a particular change will have on the number of instructions + needed to perform some unit of work. If a change increases that number, the + next logical question is "why". This generally involves looking at what part + if the code increased in instruction count. This function automates that + process so that one can easily diff counts on both an inclusive and + exclusive basis. + """ + return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive) + + def as_standardized(self) -> "CallgrindStats": + """Strip library names and some prefixes from function strings. + + When comparing two different sets of instruction counts, on stumbling + block can be path prefixes. Callgrind includes the full filepath + when reporting a function (as it should). However, this can cause + issues when diffing profiles. If a key component such as Python + or PyTorch was built in separate locations in the two profiles, which + can result in something resembling:: + + 23234231 /tmp/first_build_dir/thing.c:foo(...) + 9823794 /tmp/first_build_dir/thing.c:bar(...) + ... + 53453 .../aten/src/Aten/...:function_that_actually_changed(...) + ... + -9823794 /tmp/second_build_dir/thing.c:bar(...) + -23234231 /tmp/second_build_dir/thing.c:foo(...) + + Stripping prefixes can ameliorate this issue by regularizing the + strings and causing better cancellation of equivalent call sites + when diffing. + """ + def strip(stats: FunctionCounts) -> FunctionCounts: + transforms = ( + # PyTorch may have been built in different locations. + (r"^.+build/\.\./", "build/../"), + (r"^.+/" + re.escape("build/aten/"), "build/aten/"), + + # "Python" and "Objects" come from CPython. + (r"^.+/" + re.escape("Python/"), "Python/"), + (r"^.+/" + re.escape("Objects/"), "Objects/"), + + # Strip library name. e.g. `libtorch.so` + (r"\s\[.+\]$", ""), + ) + + for before, after in transforms: + stats = stats.transform(lambda fn: re.sub(before, after, fn)) + + return stats + + return CallgrindStats( + task_spec=self.task_spec, + number_per_run=self.number_per_run, + built_with_debug_symbols=self.built_with_debug_symbols, + baseline_inclusive_stats=strip(self.baseline_inclusive_stats), + baseline_exclusive_stats=strip(self.baseline_exclusive_stats), + stmt_inclusive_stats=strip(self.stmt_inclusive_stats), + stmt_exclusive_stats=strip(self.stmt_exclusive_stats), + + # `as_standardized` will change symbol names, so the contents will + # no longer map directly to `callgrind.out` + stmt_callgrind_out=None, + ) + + +class Serialization(enum.Enum): + PICKLE = 0 + TORCH = 1 + TORCH_JIT = 2 + + +_GLOBALS_ALLOWED_TYPES: dict[Serialization, tuple[Any, ...]] = { + Serialization.PICKLE: (str, bytes, bool, int, float, complex), + Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule), + Serialization.TORCH: (torch.nn.Module,), +} + + +class CopyIfCallgrind: + """Signal that a global may be replaced with a deserialized copy. + + See `GlobalsBridge` for why this matters. + """ + def __init__(self, value: Any, *, setup: str | None = None) -> None: + for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): + if any(isinstance(value, t) for t in supported_types): + self._value: Any = value + self._setup: str | None = setup + self._serialization: Serialization = method + break + else: + supported_str = "\n".join([ + getattr(t, "__name__", repr(t)) + for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())]) + + raise ValueError( + f"Unsupported type: {type(value)}\n" + f"`collect_callgrind` restricts globals to the following types:\n" + f"{textwrap.indent(supported_str, ' ')}" + ) + + @property + def value(self) -> Any: + return self._value + + @property + def setup(self) -> str | None: + return self._setup + + @property + def serialization(self) -> Serialization: + return self._serialization + + @staticmethod + def unwrap_all(globals: dict[str, Any]) -> dict[str, Any]: + return { + k: (v.value if isinstance(v, CopyIfCallgrind) else v) + for k, v in globals.items() + } + + +class GlobalsBridge: + """Handle the transfer of (certain) globals when collecting Callgrind statistics. + + Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to + work with `Timer.collect_callgrind`. + + Consider the following code snippet: + ``` + import pickle + import timeit + + class Counter: + value = 0 + + def __call__(self): + self.value += 1 + + counter = Counter() + timeit.Timer("counter()", globals={"counter": counter}).timeit(10) + print(counter.value) # 10 + + timeit.Timer( + "counter()", + globals={"counter": pickle.loads(pickle.dumps(counter))} + ).timeit(20) + print(counter.value) # Still 10 + ``` + + In the first case, `stmt` is executed using the objects in `globals`; + however, the addition of serialization and deserialization changes the + semantics and may meaningfully change behavior. + + This is a practical consideration when collecting Callgrind statistics. + Unlike `exec` based execution (which `timeit` uses under the hood) which + can share in-memory data structures with the caller, Callgrind collection + requires an entirely new process in order to run under Valgrind. This means + that any data structures used for statement execution will have to be + serialized and deserialized in the subprocess. + + In order to avoid surprising semantics from (user invisible) process + boundaries, what can be passed through `globals` is severely restricted + for `Timer.collect_callgrind`. It is expected that most setup should be + achievable (albeit perhaps less ergonomically) by passing a `setup` + string. + + There are, however, exceptions. One such class are TorchScripted functions. + Because they require a concrete file with source code it is not possible + to define them using a `setup` string. Another group are torch.nn.Modules, + whose construction can be complex and prohibitively cumbersome to coerce + into a `setup` string. Finally, most builtin types are sufficiently well + behaved and sufficiently common to warrant allowing as well. (e.g. + `globals={"n": 1}` is very convenient.) + + Fortunately, all have well defined serialization semantics. This class + is responsible for enabling the Valgrind subprocess to use elements in + `globals` so long as they are an allowed type. + + Caveats: + The user is required to acknowledge this serialization by wrapping + elements in `globals` with `CopyIfCallgrind`. + + While ScriptFunction and ScriptModule are expected to save and load + quite robustly, it is up to the user to ensure that an nn.Module can + un-pickle successfully. + + `torch.Tensor` and `np.ndarray` are deliberately excluded. The + serialization/deserialization process perturbs the representation of a + tensor in ways that could result in incorrect measurements. For example, + if a tensor lives in pinned CPU memory, this fact would not be preserved + by a dump, and that will in turn change the performance of certain CUDA + operations. + """ + + def __init__(self, globals: dict[str, Any], data_dir: str) -> None: + self._globals: dict[str, CopyIfCallgrind] = {} + self._data_dir = data_dir + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + if globals.get("torch", torch) is not torch: + raise ValueError("`collect_callgrind` does not support mocking out `torch`.") + + for name, value in globals.items(): + if name in ("torch", "__builtins__"): + # Torch will be imported by the collection script, and + # __builtins__ is added by Timer. + continue + + if not isinstance(value, CopyIfCallgrind): + raise ValueError( + "`collect_callgrind` requires that globals be wrapped in " + "`CopyIfCallgrind` so that serialization is explicit." + ) + + self._globals[name] = value + + def construct(self) -> str: + load_lines = [] + for name, wrapped_value in self._globals.items(): + if wrapped_value.setup is not None: + # pyrefly: ignore [bad-argument-type] + load_lines.append(textwrap.dedent(wrapped_value.setup)) + + if wrapped_value.serialization == Serialization.PICKLE: + path = os.path.join(self._data_dir, f"{name}.pkl") + load_lines.append( + # pyrefly: ignore [bad-argument-type] + f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") + with open(path, "wb") as f: + pickle.dump(wrapped_value.value, f) + + elif wrapped_value.serialization == Serialization.TORCH: + path = os.path.join(self._data_dir, f"{name}.pt") + # TODO: Figure out if we can use torch.serialization.add_safe_globals here + # Using weights_only=False after the change in + # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573 + # pyrefly: ignore [bad-argument-type] + load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)") + torch.save(wrapped_value.value, path) + + elif wrapped_value.serialization == Serialization.TORCH_JIT: + path = os.path.join(self._data_dir, f"{name}.pt") + # pyrefly: ignore [bad-argument-type] + load_lines.append(f"{name} = torch.jit.load({repr(path)})") + with open(path, "wb") as f: + torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call] + + else: + raise NotImplementedError( + f"Unknown serialization method: {wrapped_value.serialization}") + + return "\n".join(load_lines) + + +class _ValgrindWrapper: + def __init__(self) -> None: + self._bindings_module: CallgrindModuleType | None = None + valgrind_symbols = ( + "_valgrind_supported_platform", + "_valgrind_toggle", + "_valgrind_toggle_and_dump_stats", + ) + if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols): + self._supported_platform: bool = torch._C._valgrind_supported_platform() + + else: + print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.") + self._bindings_module = cpp_jit.get_compat_bindings() + if not all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols): + raise AssertionError("JIT-compiled callgrind bindings are missing required symbols") + self._supported_platform = self._bindings_module._valgrind_supported_platform() + + self._commands_available: dict[str, bool] = {} + if self._supported_platform: + # Only bother checking on supported platforms. + for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"): + self._commands_available[cmd] = not subprocess.run( + ["which", cmd], + capture_output=True, + check=False, + ).returncode + + self._build_type: str | None = None + build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) # type: ignore[no-untyped-call] + if build_search is not None: + self._build_type = build_search.groups()[0].split(",")[0] + + def _validate(self) -> None: + if not self._supported_platform: + raise OSError("Valgrind is not supported on this platform.") + + missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available] + if missing_cmds: + raise OSError("Missing: " + ", ".join(missing_cmds)) + + def collect_callgrind( + self, + task_spec: common.TaskSpec, + globals: dict[str, Any], + *, + number: int, + repeats: int, + collect_baseline: bool, + is_python: bool, + retain_out_file: bool, + ) -> tuple[CallgrindStats, ...]: + """Collect stats, and attach a reference run which can be used to filter interpreter overhead.""" + self._validate() + if not is_python and collect_baseline: + raise AssertionError("collect_baseline is only supported for Python timers") + + *task_stats, baseline_stats = self._invoke( + task_spec=task_spec, + globals=globals, + number=number, + repeats=repeats, + collect_baseline=collect_baseline, + is_python=is_python, + retain_out_file=retain_out_file, + ) + if len(task_stats) != repeats: + raise AssertionError("Unexpected number of task stats returned from _invoke") + + return tuple( + CallgrindStats( + task_spec=task_spec, + number_per_run=number, + built_with_debug_symbols=self._build_type == "RelWithDebInfo", + baseline_inclusive_stats=baseline_stats[0], + baseline_exclusive_stats=baseline_stats[1], + stmt_inclusive_stats=stmt_inclusive_stats, + stmt_exclusive_stats=stmt_exclusive_stats, + stmt_callgrind_out=out_contents, + ) + for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats + ) + + def _invoke( + self, + *, + task_spec: common.TaskSpec, + globals: dict[str, Any], + number: int, + repeats: int, + collect_baseline: bool, + is_python: bool, + retain_out_file: bool, + ) -> tuple[tuple[FunctionCounts, FunctionCounts, str | None], ...]: + """Core invocation method for Callgrind collection. + + Valgrind operates by effectively replacing the CPU with an emulated + version which allows it to instrument any code at the cost of severe + performance degradation. This has the practical effect that in order + to collect Callgrind statistics, a new process has to be created + running under `valgrind`. The steps for this process are: + + 1) Create a scratch directory. + 2) Codegen a run script. (_ValgrindWrapper._construct_script) + Inside the run script: + * Validate that Python and torch match the parent process + * Validate that it is indeed running under valgrind + * Execute `setup` and warm up `stmt` + * Begin collecting stats + * Run the `stmt` loop + * Stop collecting stats + 3) Parse the run results. + 4) Cleanup the scratch directory. + """ + working_dir = common._make_temp_dir(prefix="callgrind") + data_dir = os.path.join(working_dir, "data") + script_file = os.path.join(working_dir, "timer_callgrind.py") + callgrind_out = os.path.join(working_dir, "callgrind.out") + error_log = os.path.join(working_dir, "error.txt") + stat_log = os.path.join(working_dir, "callgrind_stat.txt") + stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log") + + def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]: + # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/ + with open(stdout_stderr_log, "wb") as f_stdout_stderr: + invocation = subprocess.run( + args, + stdout=f_stdout_stderr, + stderr=subprocess.STDOUT, + **kwargs, + ) + with open(stdout_stderr_log) as f: + return invocation, f.read() + + try: + if is_python: + if self._bindings_module is not None: + shutil.copy( + self._bindings_module.__file__, + os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1]) + ) + + script_file = os.path.join(working_dir, "timer_callgrind.py") + with open(script_file, "w") as f: + f.write(self._construct_script( + task_spec, + globals=GlobalsBridge(globals, data_dir), + number=number, + repeats=repeats, + collect_baseline=collect_baseline, + error_log=error_log, + stat_log=stat_log, + bindings=self._bindings_module)) + + run_loop_cmd = ["python", script_file] + else: + if collect_baseline: + raise AssertionError("collect_baseline must be False for non-Python timers") + run_loop_exec = cpp_jit.compile_callgrind_template( + stmt=task_spec.stmt, + setup=task_spec.setup, + global_setup=task_spec.global_setup, + ) + run_loop_cmd = [ + run_loop_exec, + "--number", str(number), + "--number-warmup", str(min(number, 10)), + "--repeats", str(repeats), + "--number-threads", str(task_spec.num_threads), + ] + + valgrind_invocation, valgrind_invocation_output = run([ + "valgrind", + "--tool=callgrind", + f"--callgrind-out-file={callgrind_out}", + "--dump-line=yes", + "--dump-instr=yes", + "--instr-atstart=yes", + "--collect-atstart=no", + ] + run_loop_cmd) + + if valgrind_invocation.returncode: + error_report = "" + if os.path.exists(error_log): + with open(error_log) as f: + error_report = f.read() + if not error_report: + error_report = "Unknown error.\n" + valgrind_invocation_output + + raise OSError(f"Failed to collect callgrind profile:\n{error_report}") + + def parse_output(fpath: str, inclusive: bool) -> FunctionCounts: + _annotate_invocation, annotate_invocation_output = run([ + "callgrind_annotate", + f"--inclusive={'yes' if inclusive else 'no'}", + "--threshold=100", + "--show-percs=no", + fpath + ], check=True) + + total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS") + begin_pattern = re.compile(r"Ir\s+file:function") + function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$") + + class ScanState(enum.Enum): + SCANNING_FOR_TOTAL = 0 + SCANNING_FOR_START = 1 + PARSING = 2 + + scan_state = ScanState.SCANNING_FOR_TOTAL + fn_counts = [] + for l in annotate_invocation_output.splitlines(keepends=False): + if scan_state == ScanState.SCANNING_FOR_TOTAL: + total_match = total_pattern.match(l) + if total_match: + program_totals = int(total_match.groups()[0].replace(",", "")) + scan_state = ScanState.SCANNING_FOR_START + + elif scan_state == ScanState.SCANNING_FOR_START: + if begin_pattern.match(l): + scan_state = ScanState.PARSING + + else: + if scan_state != ScanState.PARSING: + raise AssertionError("Failed to enter PARSING state while parsing callgrind_annotate output") + fn_match = function_pattern.match(l) + if fn_match: + ir_str, file_function = fn_match.groups() + ir = int(ir_str.replace(",", "")) + if ir == program_totals: # type: ignore[possibly-undefined] + # Callgrind includes some top level red herring symbols when + # a program dumps multiple profiles. + continue + fn_counts.append(FunctionCount(ir, file_function)) + + elif re.match(r"-+", l): + # Ignore heading separator lines. + continue + + else: + break + + if scan_state != ScanState.PARSING: + raise AssertionError(f"Failed to parse {fpath}") + return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive) + + def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, str | None]: + if i == repeats and not collect_baseline: + # Null baseline. + return ( + FunctionCounts((), inclusive=True), + FunctionCounts((), inclusive=False), + None, + ) + + fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files. + callgrind_out_contents: str | None = None + if retain_out_file: + with open(fpath) as f: + callgrind_out_contents = f.read() + + return ( + parse_output(fpath, inclusive=True), + parse_output(fpath, inclusive=False), + callgrind_out_contents + ) + + return tuple(read_results(i) for i in range(repeats + 1)) + finally: + shutil.rmtree(working_dir) + + @staticmethod + def _construct_script( + task_spec: common.TaskSpec, + globals: GlobalsBridge, + *, + number: int, + repeats: int, + collect_baseline: bool, + error_log: str, + stat_log: str, + bindings: CallgrindModuleType | None, + ) -> str: + def block_stmt(stmt: str, indent: int = 0) -> str: + """Partially unroll benchmark loop. + + The naive template looks something like: + "for _ in range({number}): {stmt}" + + However a loop in Python is surprisingly expensive, and significantly + increases the number of background Python instructions. So instead we + partially unroll the loops, with a block size of 100 chosen to keep + the instruction overhead from `range` low while also not ballooning + the size of the generated file. + """ + block_size = 100 + loop_count = number // block_size + if loop_count == 1: + # There is no point in having `for _ in range(1): ...` rather + # than just `...`, and this lets us save shave a few background + # instructions. + loop_count = 0 + remainder = number - block_size * loop_count + blocked_stmt = "" + + if loop_count: + unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4) + blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n" + + if remainder: + blocked_stmt += "\n".join([stmt] * remainder) + + return textwrap.indent(blocked_stmt, " " * indent) + + pass_baseline = ( + "callgrind_bindings._valgrind_toggle()\n" + f"{block_stmt('pass')}\n" + "callgrind_bindings._valgrind_toggle_and_dump_stats()" + ) + + return textwrap.dedent(r""" + import gc + import os + import pickle + import subprocess + import sys + import time + + # Mitigate https://github.com/pytorch/pytorch/issues/37377 + # which can sometimes cause the subprocess call to fail. + import numpy as np + + import torch + torch.set_num_threads({num_threads}) + + {bindings_import} + + PID = os.getpid() + + def log_failure(msg): + with open({error_log_repr}, "wt") as f: + f.write(msg) + sys.exit(1) + + def check_result(completed_process): + if completed_process.returncode: + log_failure(f"Command failed: {{' '.join(completed_process.args)}}") + return completed_process + + # ============================================================================= + # == Check that subprocess matches parent ===================================== + # ============================================================================= + if os.path.realpath(sys.executable) != "{parent_interpreter}": + log_failure( + "Interpreter mismatch:\n" + f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}" + ) + + if torch.__file__ != "{torch_file}": + log_failure( + "PyTorch does not match expected file:\n" + f" {{torch.__file__}}\n vs.\n {torch_file}" + ) + + # ============================================================================= + # == User specified setup ===================================================== + # ============================================================================= + # Load serialized globals + {load_globals} + + # User setup str + {setup} + + for _ in range({warmup_number}): + {indented_stmt} + + # ============================================================================= + # == Callgrind management ===================================================== + # ============================================================================= + with open("{stat_log}", "wb") as stat_file: + # If many instances of callgrind are running at once, the output of + # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE` + # to deadlock. So instead we use a file. + callgrind_stat = check_result(subprocess.run( + ["callgrind_control", "--stat"], + stdout=stat_file, + stderr=subprocess.STDOUT, + )) + + with open("{stat_log}", "rt") as stat_file: + stat_lines = stat_file.read().splitlines() + + if f"PID {{PID}}: python {{__file__}}" not in stat_lines: + log_failure("Process does not appear to be running callgrind.") + + gc.collect() + time.sleep(0.01) + + # ============================================================================= + # == User code block ========================================================== + # ============================================================================= + for _ in range({repeats}): + callgrind_bindings._valgrind_toggle() + {blocked_stmt} + callgrind_bindings._valgrind_toggle_and_dump_stats() + gc.collect() + + {baseline} + """).strip().format( + indented_stmt=textwrap.indent(task_spec.stmt, " " * 4), + blocked_stmt=block_stmt(task_spec.stmt, indent=4), + baseline=(pass_baseline if collect_baseline else ""), + number=number, + repeats=repeats, + load_globals=globals.construct(), + setup=task_spec.setup, + warmup_number=min(number, 10), + num_threads=task_spec.num_threads, + error_log_repr=repr(error_log), + stat_log=stat_log, + parent_interpreter=os.path.realpath(sys.executable), + torch_file=torch.__file__, + bindings_import=( + "import torch._C as callgrind_bindings" if bindings is None + else f"import {bindings.__name__} as callgrind_bindings"), + ) + + +CALLGRIND_SINGLETON: _ValgrindWrapper | None = None +def wrapper_singleton() -> _ValgrindWrapper: + global CALLGRIND_SINGLETON + if CALLGRIND_SINGLETON is None: + CALLGRIND_SINGLETON = _ValgrindWrapper() + return CALLGRIND_SINGLETON diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h new file mode 100644 index 0000000000000000000000000000000000000000..d33dd30932aa86b8284cb93d0e29ec646e820197 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h @@ -0,0 +1,7157 @@ +/* -*- c -*- + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (valgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of Valgrind, a dynamic binary instrumentation + framework. + + Copyright (C) 2000-2017 Julian Seward. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (valgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + + +/* This file is for inclusion into client (your!) code. + + You can use these macros to manipulate and query Valgrind's + execution inside your own programs. + + The resulting executables will still run without Valgrind, just a + little bit more slowly than they otherwise would, but otherwise + unchanged. When not running on valgrind, each client request + consumes very few (eg. 7) instructions, so the resulting performance + loss is negligible unless you plan to execute client requests + millions of times per second. Nevertheless, if that is still a + problem, you can compile with the NVALGRIND symbol defined (gcc + -DNVALGRIND) so that client requests are not even compiled in. */ + +#ifndef __VALGRIND_H +#define __VALGRIND_H + + +/* ------------------------------------------------------------------ */ +/* VERSION NUMBER OF VALGRIND */ +/* ------------------------------------------------------------------ */ + +/* Specify Valgrind's version number, so that user code can + conditionally compile based on our version number. Note that these + were introduced at version 3.6 and so do not exist in version 3.5 + or earlier. The recommended way to use them to check for "version + X.Y or later" is (eg) + +#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__) \ + && (__VALGRIND_MAJOR__ > 3 \ + || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6)) +*/ +#define __VALGRIND_MAJOR__ 3 +#define __VALGRIND_MINOR__ 17 + + +#include + +/* Nb: this file might be included in a file compiled with -ansi. So + we can't use C++ style "//" comments nor the "asm" keyword (instead + use "__asm__"). */ + +/* Derive some tags indicating what the target platform is. Note + that in this file we're using the compiler's CPP symbols for + identifying architectures, which are different to the ones we use + within the rest of Valgrind. Note, __powerpc__ is active for both + 32 and 64-bit PPC, whereas __powerpc64__ is only active for the + latter (on Linux, that is). + + Misc note: how to find out what's predefined in gcc by default: + gcc -Wp,-dM somefile.c +*/ +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_arm64_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + + +#if defined(__APPLE__) && defined(__i386__) +# define PLAT_x86_darwin 1 +#elif defined(__APPLE__) && defined(__x86_64__) +# define PLAT_amd64_darwin 1 +#elif (defined(__MINGW32__) && defined(__i386__)) \ + || defined(__CYGWIN32__) \ + || (defined(_WIN32) && defined(_M_IX86)) +# define PLAT_x86_win32 1 +#elif (defined(__MINGW32__) && defined(__x86_64__)) \ + || (defined(_WIN32) && defined(_M_X64)) +/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. */ +# define PLAT_amd64_win64 1 +#elif defined(__linux__) && defined(__i386__) +# define PLAT_x86_linux 1 +#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__) +# define PLAT_amd64_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__) +# define PLAT_ppc32_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2 +/* Big Endian uses ELF version 1 */ +# define PLAT_ppc64be_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2 +/* Little Endian uses ELF version 2 */ +# define PLAT_ppc64le_linux 1 +#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__) +# define PLAT_arm_linux 1 +#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__) +# define PLAT_arm64_linux 1 +#elif defined(__linux__) && defined(__s390__) && defined(__s390x__) +# define PLAT_s390x_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==64) +# define PLAT_mips64_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==32) +# define PLAT_mips32_linux 1 +#elif defined(__linux__) && defined(__nanomips__) +# define PLAT_nanomips_linux 1 +#elif defined(__sun) && defined(__i386__) +# define PLAT_x86_solaris 1 +#elif defined(__sun) && defined(__x86_64__) +# define PLAT_amd64_solaris 1 +#else +/* If we're not compiling for our target platform, don't generate + any inline asms. */ +# if !defined(NVALGRIND) +# define NVALGRIND 1 +# endif +#endif + + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS. There is nothing */ +/* in here of use to end-users -- skip to the next section. */ +/* ------------------------------------------------------------------ */ + +/* + * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client + * request. Accepts both pointers and integers as arguments. + * + * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind + * client request that does not return a value. + + * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind + * client request and whose value equals the client request result. Accepts + * both pointers and integers as arguments. Note that such calls are not + * necessarily pure functions -- they may have side effects. + */ + +#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default, \ + _zzq_request, _zzq_arg1, _zzq_arg2, \ + _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default), \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1, \ + _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#if defined(NVALGRIND) + +/* Define NVALGRIND to completely remove the Valgrind magic sequence + from the compiled code (analogous to NDEBUG's effects on + assert()) */ +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + (_zzq_default) + +#else /* ! NVALGRIND */ + +/* The following defines the magic code sequences which the JITter + spots and handles magically. Don't look too closely at them as + they will rot your brain. + + The assembly code sequences for all architectures is in this one + file. This is because this file must be stand-alone, and we don't + want to have multiple files. + + For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default + value gets put in the return slot, so that everything works when + this is executed not under Valgrind. Args are passed in a memory + block, and so there's no intrinsic limit to the number that could + be passed, but it's currently five. + + The macro args are: + _zzq_rlval result lvalue + _zzq_default default value (result returned when running on real CPU) + _zzq_request request code + _zzq_arg1..5 request params + + The other two macros are used to support function wrapping, and are + a lot simpler. VALGRIND_GET_NR_CONTEXT returns the value of the + guest's NRADDR pseudo-register and whatever other information is + needed to safely run the call original from the wrapper: on + ppc64-linux, the R2 value at the divert point is also needed. This + information is abstracted into a user-visible type, OrigFn. + + VALGRIND_CALL_NOREDIR_* behaves the same as the following on the + guest, but guarantees that the branch instruction will not be + redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64: + branch-and-link-to-r11. VALGRIND_CALL_NOREDIR is just text, not a + complete inline asm, since it needs to be combined with more magic + inline asm stuff to be useful. +*/ + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || (defined(PLAT_x86_win32) && defined(__GNUC__)) \ + || defined(PLAT_x86_solaris) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "roll $3, %%edi ; roll $13, %%edi\n\t" \ + "roll $29, %%edi ; roll $19, %%edi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EDX = client_request ( %EAX ) */ \ + "xchgl %%ebx,%%ebx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + "xchgl %%ecx,%%ecx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%EAX */ \ + "xchgl %%edx,%%edx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgl %%edi,%%edi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__) + || PLAT_x86_solaris */ + +/* ------------------------- x86-Win32 ------------------------- */ + +#if defined(PLAT_x86_win32) && !defined(__GNUC__) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#if defined(_MSC_VER) + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm rol edi, 3 __asm rol edi, 13 \ + __asm rol edi, 29 __asm rol edi, 19 + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + valgrind_do_client_request_expr((uintptr_t)(_zzq_default), \ + (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1), \ + (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3), \ + (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5)) + +static __inline uintptr_t +valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request, + uintptr_t _zzq_arg1, uintptr_t _zzq_arg2, + uintptr_t _zzq_arg3, uintptr_t _zzq_arg4, + uintptr_t _zzq_arg5) +{ + volatile uintptr_t _zzq_args[6]; + volatile unsigned int _zzq_result; + _zzq_args[0] = (uintptr_t)(_zzq_request); + _zzq_args[1] = (uintptr_t)(_zzq_arg1); + _zzq_args[2] = (uintptr_t)(_zzq_arg2); + _zzq_args[3] = (uintptr_t)(_zzq_arg3); + _zzq_args[4] = (uintptr_t)(_zzq_arg4); + _zzq_args[5] = (uintptr_t)(_zzq_arg5); + __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default + __SPECIAL_INSTRUCTION_PREAMBLE + /* %EDX = client_request ( %EAX ) */ + __asm xchg ebx,ebx + __asm mov _zzq_result, edx + } + return _zzq_result; +} + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + __asm xchg ecx,ecx \ + __asm mov __addr, eax \ + } \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX ERROR + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm xchg edi,edi \ + } \ + } while (0) + +#else +#error Unsupported compiler. +#endif + +#endif /* PLAT_x86_win32 */ + +/* ----------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) \ + || (defined(PLAT_amd64_win64) && defined(__GNUC__)) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rolq $3, %%rdi ; rolq $13, %%rdi\n\t" \ + "rolq $61, %%rdi ; rolq $51, %%rdi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RDX = client_request ( %RAX ) */ \ + "xchgq %%rbx,%%rbx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RAX = guest_NRADDR */ \ + "xchgq %%rcx,%%rcx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_RAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%RAX */ \ + "xchgq %%rdx,%%rdx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgq %%rdi,%%rdi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------- amd64-Win64 ------------------------- */ + +#if defined(PLAT_amd64_win64) && !defined(__GNUC__) + +#error Unsupported compiler. + +#endif /* PLAT_amd64_win64 */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rlwinm 0,0,3,0,31 ; rlwinm 0,0,13,0,31\n\t" \ + "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned int _zzq_args[6]; \ + unsigned int _zzq_result; \ + unsigned int* _zzq_ptr; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +#if defined(PLAT_ppc64le_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R12 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "mov r12, r12, ror #3 ; mov r12, r12, ror #13 \n\t" \ + "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("mov r3, %1\n\t" /*default*/ \ + "mov r4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = client_request ( R4 ) */ \ + "orr r10, r10, r10\n\t" \ + "mov %0, r3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "cc","memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = guest_NRADDR */ \ + "orr r11, r11, r11\n\t" \ + "mov %0, r3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R4 */ \ + "orr r12, r12, r12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr r9, r9, r9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------- */ + +#if defined(PLAT_arm64_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "ror x12, x12, #3 ; ror x12, x12, #13 \n\t" \ + "ror x12, x12, #51 ; ror x12, x12, #61 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("mov x3, %1\n\t" /*default*/ \ + "mov x4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = client_request ( X4 ) */ \ + "orr x10, x10, x10\n\t" \ + "mov %0, x3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" ((unsigned long int)(_zzq_default)), \ + "r" (&_zzq_args[0]) \ + : "cc","memory", "x3", "x4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = guest_NRADDR */ \ + "orr x11, x11, x11\n\t" \ + "mov %0, x3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "x3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir X8 */ \ + "orr x12, x12, x12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr x9, x9, x9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------ s390x-linux ------------------------ */ + +#if defined(PLAT_s390x_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific + * code. This detection is implemented in platform specific toIR.c + * (e.g. VEX/priv/guest_s390_decoder.c). + */ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "lr 15,15\n\t" \ + "lr 1,1\n\t" \ + "lr 2,2\n\t" \ + "lr 3,3\n\t" + +#define __CLIENT_REQUEST_CODE "lr 2,2\n\t" +#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t" +#define __CALL_NO_REDIR_CODE "lr 4,4\n\t" +#define __VEX_INJECT_IR_CODE "lr 5,5\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(/* r2 = args */ \ + "lgr 2,%1\n\t" \ + /* r3 = default */ \ + "lgr 3,%2\n\t" \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CLIENT_REQUEST_CODE \ + /* results = r3 */ \ + "lgr %0, 3\n\t" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "2", "3", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __GET_NR_CONTEXT_CODE \ + "lgr %0, 3\n\t" \ + : "=a" (__addr) \ + : \ + : "cc", "3", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_R1 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CALL_NO_REDIR_CODE + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __VEX_INJECT_IR_CODE); \ + } while (0) + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ---------------- */ + +#if defined(PLAT_mips32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +/* .word 0x342 + * .word 0x742 + * .word 0xC2 + * .word 0x4C2*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "srl $0, $0, 13\n\t" \ + "srl $0, $0, 29\n\t" \ + "srl $0, $0, 3\n\t" \ + "srl $0, $0, 19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* T3 = client_request ( T4 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %t9 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%t9 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- mips64-linux ---------------- */ + +#if defined(PLAT_mips64_linux) + +typedef + struct { + unsigned long nraddr; /* where's the code? */ + } + OrigFn; + +/* dsll $0,$0, 3 + * dsll $0,$0, 13 + * dsll $0,$0, 29 + * dsll $0,$0, 19*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "dsll $0,$0, 3 ; dsll $0,$0,13\n\t" \ + "dsll $0,$0,29 ; dsll $0,$0,19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = client_request ( $12 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +#if defined(PLAT_nanomips_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; +/* + 8000 c04d srl zero, zero, 13 + 8000 c05d srl zero, zero, 29 + 8000 c043 srl zero, zero, 3 + 8000 c053 srl zero, zero, 19 +*/ + +#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \ + "srl[32] $zero, $zero, 29 \n\t" \ + "srl[32] $zero, $zero, 3 \n\t" \ + "srl[32] $zero, $zero, 19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $a7, %1\n\t" /* default */ \ + "move $t0, %2\n\t" /* ptr */ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = client_request( $t0 ) */ \ + "or[32] $t0, $t0, $t0\n\t" \ + "move %0, $a7\n\t" /* result */ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$a7", "$t0", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = guest_NRADDR */ \ + "or[32] $t1, $t1, $t1\n\t" \ + "move %0, $a7" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$a7"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or[32] $t2, $t2, $t2\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or[32] $t3, $t3, $t3\n\t" \ + ); \ + } while (0) + +#endif +/* Insert assembly code for other platforms here... */ + +#endif /* NVALGRIND */ + + +/* ------------------------------------------------------------------ */ +/* PLATFORM SPECIFICS for FUNCTION WRAPPING. This is all very */ +/* ugly. It's the least-worst tradeoff I can think of. */ +/* ------------------------------------------------------------------ */ + +/* This section defines magic (a.k.a appalling-hack) macros for doing + guaranteed-no-redirection macros, so as to get from function + wrappers to the functions they are wrapping. The whole point is to + construct standard call sequences, but to do the call itself with a + special no-redirect call pseudo-instruction that the JIT + understands and handles specially. This section is long and + repetitious, and I can't see a way to make it shorter. + + The naming scheme is as follows: + + CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc} + + 'W' stands for "word" and 'v' for "void". Hence there are + different macros for calling arity 0, 1, 2, 3, 4, etc, functions, + and for each, the possibility of returning a word-typed result, or + no result. +*/ + +/* Use these to write the name of your wrapper. NOTE: duplicates + VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h. NOTE also: inserts + the default behaviour equivalance class tag "0000" into the name. + See pub_tool_redir.h for details -- normally you don't need to + think about this, though. */ + +/* Use an extra level of macroisation so as to ensure the soname/fnname + args are fully macro-expanded before pasting them together. */ +#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd + +#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgw00000ZU_,soname,_,fnname) + +#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname) + +/* Use this macro from within a wrapper function to collect the + context (address and possibly other info) of the original function. + Once you have that you can then use it in one of the CALL_FN_ + macros. The type of the argument _lval is OrigFn. */ +#define VALGRIND_GET_ORIG_FN(_lval) VALGRIND_GET_NR_CONTEXT(_lval) + +/* Also provide end-user facilities for function replacement, rather + than wrapping. A replacement function differs from a wrapper in + that it has no way to get hold of the original function being + called, and hence no way to call onwards to it. In a replacement + function, VALGRIND_GET_ORIG_FN always returns zero. */ + +#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgr00000ZU_,soname,_,fnname) + +#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname) + +/* Derivatives of the main macros below, for calling functions + returning void. */ + +#define CALL_FN_v_v(fnptr) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_v(_junk,fnptr); } while (0) + +#define CALL_FN_v_W(fnptr, arg1) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_W(_junk,fnptr,arg1); } while (0) + +#define CALL_FN_v_WW(fnptr, arg1,arg2) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) + +#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) + +#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) + +#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) + +#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) + +#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || defined(PLAT_x86_solaris) + +/* These regs are trashed by the hidden call. No need to mention eax + as gcc can already see that, plus causes gcc to bomb. */ +#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movl %%esp,%%edi\n\t" \ + "andl $0xfffffff0,%%esp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movl %%edi,%%esp\n\t" + +/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 48(%%eax)\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */ + +/* ---------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) + +/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi", \ + "rdi", "r8", "r9", "r10", "r11" + +/* This is all pretty complex. It's so as to make stack unwinding + work reliably. See bug 243270. The basic problem is the sub and + add of 128 of %rsp in all of the following macros. If gcc believes + the CFA is in %rsp, then unwinding may fail, because what's at the + CFA is not what gcc "expected" when it constructs the CFIs for the + places where the macros are instantiated. + + But we can't just add a CFI annotation to increase the CFA offset + by 128, to match the sub of 128 from %rsp, because we don't know + whether gcc has chosen %rsp as the CFA at that point, or whether it + has chosen some other register (eg, %rbp). In the latter case, + adding a CFI annotation to change the CFA offset is simply wrong. + + So the solution is to get hold of the CFA using + __builtin_dwarf_cfa(), put it in a known register, and add a + CFI annotation to say what the register is. We choose %rbp for + this (perhaps perversely), because: + + (1) %rbp is already subject to unwinding. If a new register was + chosen then the unwinder would have to unwind it in all stack + traces, which is expensive, and + + (2) %rbp is already subject to precise exception updates in the + JIT. If a new register was chosen, we'd have to have precise + exceptions for it too, which reduces performance of the + generated code. + + However .. one extra complication. We can't just whack the result + of __builtin_dwarf_cfa() into %rbp and then add %rbp to the + list of trashed registers at the end of the inline assembly + fragments; gcc won't allow %rbp to appear in that list. Hence + instead we need to stash %rbp in %r15 for the duration of the asm, + and say that %r15 is trashed instead. gcc seems happy to go with + that. + + Oh .. and this all needs to be conditionalised so that it is + unchanged from before this commit, when compiled with older gccs + that don't support __builtin_dwarf_cfa. Furthermore, since + this header file is freestanding, it has to be independent of + config.h, and so the following conditionalisation cannot depend on + configure time checks. + + Although it's not clear from + 'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)', + this expression excludes Darwin. + .cfi directives in Darwin assembly appear to be completely + different and I haven't investigated how they work. + + For even more entertainment value, note we have to use the + completely undocumented __builtin_dwarf_cfa(), which appears to + really compute the CFA, whereas __builtin_frame_address(0) claims + to but actually doesn't. See + https://bugs.kde.org/show_bug.cgi?id=243270#c47 +*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"r"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + "movq %%rbp, %%r15\n\t" \ + "movq %2, %%rbp\n\t" \ + ".cfi_remember_state\n\t" \ + ".cfi_def_cfa rbp, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "movq %%r15, %%rbp\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movq %%rsp,%%r14\n\t" \ + "andq $0xfffffffffffffff0,%%rsp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movq %%r14,%%rsp\n\t" + +/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned + long) == 8. */ + +/* NB 9 Sept 07. There is a nasty kludge here in all these CALL_FN_ + macros. In order not to trash the stack redzone, we need to drop + %rsp by 128 before the hidden call, and restore afterwards. The + nastyness is that it is only by luck that the stack still appears + to be unwindable during the hidden call - since then the behaviour + of any routine using this macro does not match what the CFI data + says. Sigh. + + Why is this important? Imagine that a wrapper has a stack + allocated local, and passes to the hidden call, a pointer to it. + Because gcc does not know about the hidden call, it may allocate + that local in the redzone. Unfortunately the hidden call may then + trash it before it comes to use it. So we must step clear of the + redzone, for the duration of the hidden call, to make it safe. + + Probably the same problem afflicts the other redzone-style ABIs too + (ppc64-linux); but for those, the stack is + self describing (none of this CFI nonsense) so at least messing + with the stack pointer doesn't give a danger of non-unwindable + stack. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 96(%%rax)\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +/* This is useful for finding out about the on-stack stuff: + + extern int f9 ( int,int,int,int,int,int,int,int,int ); + extern int f10 ( int,int,int,int,int,int,int,int,int,int ); + extern int f11 ( int,int,int,int,int,int,int,int,int,int,int ); + extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int ); + + int g9 ( void ) { + return f9(11,22,33,44,55,66,77,88,99); + } + int g10 ( void ) { + return f10(11,22,33,44,55,66,77,88,99,110); + } + int g11 ( void ) { + return f11(11,22,33,44,55,66,77,88,99,110,121); + } + int g12 ( void ) { + return f12(11,22,33,44,55,66,77,88,99,110,121,132); + } +*/ + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rlwinm 1,1,0,0,27\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc32-linux, + sizeof(unsigned long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg12 */ \ + "lwz 3,48(11)\n\t" \ + "stw 3,20(1)\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(11)\n\t" \ + "std 3,136(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +/* ------------------------- ppc64le-linux ----------------------- */ +#if defined(PLAT_ppc64le_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(12)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +/* This is a bit tricky. We store the original stack pointer in r10 + as it is callee-saves. gcc doesn't allow the use of r11 for some + reason. Also, we can't directly "bic" the stack pointer in thumb + mode since r13 isn't an allowed register number in that context. + So use r4 as a temporary, since that is about to get trashed + anyway, just after each use of this macro. Side effect is we need + to be very careful about any future changes, since + VALGRIND_ALIGN_STACK simply assumes r4 is usable. */ +#define VALGRIND_ALIGN_STACK \ + "mov r10, sp\n\t" \ + "mov r4, sp\n\t" \ + "bic r4, r4, #7\n\t" \ + "mov sp, r4\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, r10\n\t" + +/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "push {r0, r1, r2, r3} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "ldr r2, [%1, #48] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------ */ + +#if defined(PLAT_arm64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9", \ + "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", \ + "x18", "x19", "x20", "x30", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", \ + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", \ + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", \ + "v26", "v27", "v28", "v29", "v30", "v31" + +/* x21 is callee-saved, so we can use it to save and restore SP around + the hidden call. */ +#define VALGRIND_ALIGN_STACK \ + "mov x21, sp\n\t" \ + "bic sp, x21, #15\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, x21\n\t" + +/* These CALL_FN_ macros assume that on arm64-linux, + sizeof(unsigned long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11, \ + arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1, #96] \n\t" \ + "str x8, [sp, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------- s390x-linux ------------------------- */ + +#if defined(PLAT_s390x_linux) + +/* Similar workaround as amd64 (see above), but we use r11 as frame + pointer and save the old r11 in r7. r11 might be used for + argvec, therefore we copy argvec in r1 since r1 is clobbered + after the call anyway. */ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"d"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + ".cfi_remember_state\n\t" \ + "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */ \ + "lgr 7,11\n\t" \ + "lgr 11,%2\n\t" \ + ".cfi_def_cfa r11, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "lgr 11, 7\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE \ + "lgr 1,%1\n\t" +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Nb: On s390 the stack pointer is properly aligned *at all times* + according to the s390 GCC maintainer. (The ABI specification is not + precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and + VALGRIND_RESTORE_STACK are not defined here. */ + +/* These regs are trashed by the hidden call. Note that we overwrite + r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the + function a proper return address. All others are ABI defined call + clobbers. */ +#if defined(__VX__) || defined(__S390_VX__) +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", \ + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \ + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", \ + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" +#else +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7" +#endif + +/* Nb: Although r11 is modified in the asm snippets below (inside + VALGRIND_CFI_PROLOGUE) it is not listed in the clobber section, for + two reasons: + (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is not + modified + (2) GCC will complain that r11 cannot appear inside a clobber section, + when compiled with -O -fno-omit-frame-pointer + */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 1, 0(1)\n\t" /* target->r1 */ \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "d" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +/* The call abi has the arguments in r2-r6 and stack */ +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1, arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-168\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,168\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-176\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,176\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-184\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,184\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-192\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,192\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-200\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,200\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-208\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,208\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-216\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "mvc 208(8,15), 96(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,216\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ----------------------- */ + +#if defined(PLAT_mips32_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16\n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" /* arg1*/ \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 24\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 24 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "nop\n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 56\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 48(%1) \n\t" \ + "sw $4, 44($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 56 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- nanomips-linux -------------------- */ + +#if defined(PLAT_nanomips_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2", \ +"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3", \ +"$t8","$t9", "$at" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + "lw $a7,32(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9,48(%1) \n\t" \ + "sw $t9,12($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_nanomips_linux */ + +/* ------------------------- mips64-linux ------------------------- */ + +#if defined(PLAT_mips64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips64-linux, + sizeof(long long) == 8. */ + +#define MIPS64_LONG2REG_CAST(x) ((long long)(long)x) + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[1]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + __asm__ volatile( \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[2]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" /* arg1*/ \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[3]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[4]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[5]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[6]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[7]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[8]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[9]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[10]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + __asm__ volatile( \ + "dsubu $29, $29, 8\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 8\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[11]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + __asm__ volatile( \ + "dsubu $29, $29, 16\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 16\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[12]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + __asm__ volatile( \ + "dsubu $29, $29, 24\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 24\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[13]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + _argvec[12] = MIPS64_LONG2REG_CAST(arg12); \ + __asm__ volatile( \ + "dsubu $29, $29, 32\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 96(%1)\n\t" \ + "sd $4, 24($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 32\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS. */ +/* */ +/* ------------------------------------------------------------------ */ + +/* Some request codes. There are many more of these, but most are not + exposed to end-user view. These are the public ones, all of the + form 0x1000 + small_number. + + Core ones are in the range 0x00000000--0x0000ffff. The non-public + ones start at 0x2000. +*/ + +/* These macros are used by tools -- they must be public, but don't + embed them into other programs. */ +#define VG_USERREQ_TOOL_BASE(a,b) \ + ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16)) +#define VG_IS_TOOL_USERREQ(a, b, v) \ + (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000)) + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE NUMERIC VALUES OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end of the most + relevant group. */ +typedef + enum { VG_USERREQ__RUNNING_ON_VALGRIND = 0x1001, + VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002, + + /* These allow any function to be called from the simulated + CPU but run on the real CPU. Nb: the first arg passed to + the function is always the ThreadId of the running + thread! So CLIENT_CALL0 actually requires a 1 arg + function, etc. */ + VG_USERREQ__CLIENT_CALL0 = 0x1101, + VG_USERREQ__CLIENT_CALL1 = 0x1102, + VG_USERREQ__CLIENT_CALL2 = 0x1103, + VG_USERREQ__CLIENT_CALL3 = 0x1104, + + /* Can be useful in regression testing suites -- eg. can + send Valgrind's output to /dev/null and still count + errors. */ + VG_USERREQ__COUNT_ERRORS = 0x1201, + + /* Allows the client program and/or gdbserver to execute a monitor + command. */ + VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202, + + /* Allows the client program to change a dynamic command line + option. */ + VG_USERREQ__CLO_CHANGE = 0x1203, + + /* These are useful and can be interpreted by any tool that + tracks malloc() et al, by using vg_replace_malloc.c. */ + VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301, + VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b, + VG_USERREQ__FREELIKE_BLOCK = 0x1302, + /* Memory pool support. */ + VG_USERREQ__CREATE_MEMPOOL = 0x1303, + VG_USERREQ__DESTROY_MEMPOOL = 0x1304, + VG_USERREQ__MEMPOOL_ALLOC = 0x1305, + VG_USERREQ__MEMPOOL_FREE = 0x1306, + VG_USERREQ__MEMPOOL_TRIM = 0x1307, + VG_USERREQ__MOVE_MEMPOOL = 0x1308, + VG_USERREQ__MEMPOOL_CHANGE = 0x1309, + VG_USERREQ__MEMPOOL_EXISTS = 0x130a, + + /* Allow printfs to valgrind log. */ + /* The first two pass the va_list argument by value, which + assumes it is the same size as or smaller than a UWord, + which generally isn't the case. Hence are deprecated. + The second two pass the vargs by reference and so are + immune to this problem. */ + /* both :: char* fmt, va_list vargs (DEPRECATED) */ + VG_USERREQ__PRINTF = 0x1401, + VG_USERREQ__PRINTF_BACKTRACE = 0x1402, + /* both :: char* fmt, va_list* vargs */ + VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404, + + /* Stack support. */ + VG_USERREQ__STACK_REGISTER = 0x1501, + VG_USERREQ__STACK_DEREGISTER = 0x1502, + VG_USERREQ__STACK_CHANGE = 0x1503, + + /* Wine support */ + VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601, + + /* Querying of debug info. */ + VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701, + + /* Disable/enable error reporting level. Takes a single + Word arg which is the delta to this thread's error + disablement indicator. Hence 1 disables or further + disables errors, and -1 moves back towards enablement. + Other values are not allowed. */ + VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801, + + /* Some requests used for Valgrind internal, such as + self-test or self-hosting. */ + /* Initialise IR injection */ + VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901, + /* Used by Inner Valgrind to inform Outer Valgrind where to + find the list of inner guest threads */ + VG_USERREQ__INNER_THREADS = 0x1902 + } Vg_ClientRequest; + +#if !defined(__GNUC__) +# define __extension__ /* */ +#endif + + +/* Returns the number of Valgrinds this code is running under. That + is, 0 if running natively, 1 if running under Valgrind, 2 if + running under Valgrind which is running under another Valgrind, + etc. */ +#define RUNNING_ON_VALGRIND \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */, \ + VG_USERREQ__RUNNING_ON_VALGRIND, \ + 0, 0, 0, 0, 0) \ + + +/* Discard translation of code in the range [_qzz_addr .. _qzz_addr + + _qzz_len - 1]. Useful if you are debugging a JITter or some such, + since it provides a way to make sure valgrind will retranslate the + invalidated area. Returns no value. */ +#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS, \ + _qzz_addr, _qzz_len, 0, 0, 0) + +#define VALGRIND_INNER_THREADS(_qzz_addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS, \ + _qzz_addr, 0, 0, 0, 0) + + +/* These requests are for getting Valgrind itself to print something. + Possibly with a backtrace. This is a really ugly hack. The return value + is the number of characters printed, excluding the "**** " part at the + start and the backtrace (if present). */ + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +/* Modern GCC will optimize the static routine out if unused, + and unused attribute will shut down warnings about it. */ +static int VALGRIND_PRINTF(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF_BACKTRACE(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + + +/* These requests allow control to move from the simulated CPU to the + real CPU, calling an arbitrary function. + + Note that the current ThreadId is inserted as the first argument. + So this call: + + VALGRIND_NON_SIMD_CALL2(f, arg1, arg2) + + requires f to have this signature: + + Word f(Word tid, Word arg1, Word arg2) + + where "Word" is a word-sized type. + + Note that these client requests are not entirely reliable. For example, + if you call a function with them that subsequently calls printf(), + there's a high chance Valgrind will crash. Generally, your prospects of + these working are made higher if the called function does not refer to + any global variables, and does not refer to any libc or other functions + (printf et al). Any kind of entanglement with libc or dynamic linking is + likely to have a bad outcome, for tricky reasons which we've grappled + with a lot in the past. +*/ +#define VALGRIND_NON_SIMD_CALL0(_qyy_fn) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL0, \ + _qyy_fn, \ + 0, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL1, \ + _qyy_fn, \ + _qyy_arg1, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL2, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, 0, 0) + +#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL3, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, \ + _qyy_arg3, 0) + + +/* Counts the number of errors that have been recorded by a tool. Nb: + the tool must record the errors with VG_(maybe_record_error)() or + VG_(unique_error)() for them to be counted. */ +#define VALGRIND_COUNT_ERRORS \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + 0 /* default return */, \ + VG_USERREQ__COUNT_ERRORS, \ + 0, 0, 0, 0, 0) + +/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing + when heap blocks are allocated in order to give accurate results. This + happens automatically for the standard allocator functions such as + malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete, + delete[], etc. + + But if your program uses a custom allocator, this doesn't automatically + happen, and Valgrind will not do as well. For example, if you allocate + superblocks with mmap() and then allocates chunks of the superblocks, all + Valgrind's observations will be at the mmap() level and it won't know that + the chunks should be considered separate entities. In Memcheck's case, + that means you probably won't get heap block overrun detection (because + there won't be redzones marked as unaddressable) and you definitely won't + get any leak detection. + + The following client requests allow a custom allocator to be annotated so + that it can be handled accurately by Valgrind. + + VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated + by a malloc()-like function. For Memcheck (an illustrative case), this + does two things: + + - It records that the block has been allocated. This means any addresses + within the block mentioned in error messages will be + identified as belonging to the block. It also means that if the block + isn't freed it will be detected by the leak checker. + + - It marks the block as being addressable and undefined (if 'is_zeroed' is + not set), or addressable and defined (if 'is_zeroed' is set). This + controls how accesses to the block by the program are handled. + + 'addr' is the start of the usable block (ie. after any + redzone), 'sizeB' is its size. 'rzB' is the redzone size if the allocator + can apply redzones -- these are blocks of padding at the start and end of + each block. Adding redzones is recommended as it makes it much more likely + Valgrind will spot block overruns. `is_zeroed' indicates if the memory is + zeroed (or filled with another predictable value), as is the case for + calloc(). + + VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a + heap block -- that will be used by the client program -- is allocated. + It's best to put it at the outermost level of the allocator if possible; + for example, if you have a function my_alloc() which calls + internal_alloc(), and the client request is put inside internal_alloc(), + stack traces relating to the heap block will contain entries for both + my_alloc() and internal_alloc(), which is probably not what you want. + + For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out + custom blocks from within a heap block, B, that has been allocated with + malloc/calloc/new/etc, then block B will be *ignored* during leak-checking + -- the custom blocks will take precedence. + + VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK. For + Memcheck, it does two things: + + - It records that the block has been deallocated. This assumes that the + block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - It marks the block as being unaddressable. + + VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a + heap block is deallocated. + + VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For + Memcheck, it does four things: + + - It records that the size of a block has been changed. This assumes that + the block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - If the block shrunk, it marks the freed memory as being unaddressable. + + - If the block grew, it marks the new area as undefined and defines a red + zone past the end of the new block. + + - The V-bits of the overlap between the old and the new block are preserved. + + VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block + and before deallocation of the old block. + + In many cases, these three client requests will not be enough to get your + allocator working well with Memcheck. More specifically, if your allocator + writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call + will be necessary to mark the memory as addressable just before the zeroing + occurs, otherwise you'll get a lot of invalid write errors. For example, + you'll need to do this if your allocator recycles freed blocks, but it + zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK). + Alternatively, if your allocator reuses freed blocks for allocator-internal + data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary. + + Really, what's happening is a blurring of the lines between the client + program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the + memory should be considered unaddressable to the client program, but the + allocator knows more than the rest of the client program and so may be able + to safely access it. Extra client requests are necessary for Valgrind to + understand the distinction between the allocator and the rest of the + program. + + Ignored if addr == 0. +*/ +#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK, \ + addr, sizeB, rzB, is_zeroed, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK, \ + addr, oldSizeB, newSizeB, rzB, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_FREELIKE_BLOCK(addr, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK, \ + addr, rzB, 0, 0, 0) + +/* Create a memory pool. */ +#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, 0, 0) + +/* Create a memory pool with some flags specifying extended behaviour. + When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL. + + The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory + associated with the pool using VALGRIND_MEMPOOL_ALLOC will be used + by the application as superblocks to dole out MALLOC_LIKE blocks using + VALGRIND_MALLOCLIKE_BLOCK. In other words, a meta pool is a "2 levels" + pool : first level is the blocks described by VALGRIND_MEMPOOL_ALLOC. + The second level blocks are described using VALGRIND_MALLOCLIKE_BLOCK. + Note that the association between the pool and the second level blocks + is implicit : second level blocks will be located inside first level + blocks. It is necessary to use the VALGRIND_MEMPOOL_METAPOOL flag + for such 2 levels pools, as otherwise valgrind will detect overlapping + memory blocks, and will abort execution (e.g. during leak search). + + Such a meta pool can also be marked as an 'auto free' pool using the flag + VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with the + VALGRIND_MEMPOOL_METAPOOL. For an 'auto free' pool, VALGRIND_MEMPOOL_FREE + will automatically free the second level blocks that are contained + inside the first level block freed with VALGRIND_MEMPOOL_FREE. + In other words, calling VALGRIND_MEMPOOL_FREE will cause implicit calls + to VALGRIND_FREELIKE_BLOCK for all the second level blocks included + in the first level block. + Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag + without the VALGRIND_MEMPOOL_METAPOOL flag. +*/ +#define VALGRIND_MEMPOOL_AUTO_FREE 1 +#define VALGRIND_MEMPOOL_METAPOOL 2 +#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, flags, 0) + +/* Destroy a memory pool. */ +#define VALGRIND_DESTROY_MEMPOOL(pool) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL, \ + pool, 0, 0, 0, 0) + +/* Associate a piece of memory with a memory pool. */ +#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC, \ + pool, addr, size, 0, 0) + +/* Disassociate a piece of memory from a memory pool. */ +#define VALGRIND_MEMPOOL_FREE(pool, addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE, \ + pool, addr, 0, 0, 0) + +/* Disassociate any pieces outside a particular range. */ +#define VALGRIND_MEMPOOL_TRIM(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM, \ + pool, addr, size, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MOVE_MEMPOOL(poolA, poolB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL, \ + poolA, poolB, 0, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE, \ + pool, addrA, addrB, size, 0) + +/* Return 1 if a mempool exists, else 0. */ +#define VALGRIND_MEMPOOL_EXISTS(pool) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MEMPOOL_EXISTS, \ + pool, 0, 0, 0, 0) + +/* Mark a piece of memory as being a stack. Returns a stack id. + start is the lowest addressable stack byte, end is the highest + addressable stack byte. */ +#define VALGRIND_STACK_REGISTER(start, end) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__STACK_REGISTER, \ + start, end, 0, 0, 0) + +/* Unmark the piece of memory associated with a stack id as being a + stack. */ +#define VALGRIND_STACK_DEREGISTER(id) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \ + id, 0, 0, 0, 0) + +/* Change the start and end address of the stack id. + start is the new lowest addressable stack byte, end is the new highest + addressable stack byte. */ +#define VALGRIND_STACK_CHANGE(id, start, end) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE, \ + id, start, end, 0, 0) + +/* Load PDB debug info for Wine PE image_map. */ +#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \ + fd, ptr, total_size, delta, 0) + +/* Map a code address to a source file name and line number. buf64 + must point to a 64-byte buffer in the caller's address space. The + result will be dumped in there and is guaranteed to be zero + terminated. If no info is found, the first byte is set to zero. */ +#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MAP_IP_TO_SRCLOC, \ + addr, buf64, 0, 0, 0) + +/* Disable error reporting for this thread. Behaves in a stack like + way, so you can safely call this multiple times provided that + VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times + to re-enable reporting. The first call of this macro disables + reporting. Subsequent calls have no effect except to increase the + number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable + reporting. Child threads do not inherit this setting from their + parents -- they are always created with reporting enabled. */ +#define VALGRIND_DISABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + 1, 0, 0, 0, 0) + +/* Re-enable error reporting, as per comments on + VALGRIND_DISABLE_ERROR_REPORTING. */ +#define VALGRIND_ENABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + -1, 0, 0, 0, 0) + +/* Execute a monitor command from the client program. + If a connection is opened with GDB, the output will be sent + according to the output mode set for vgdb. + If no connection is opened, output will go to the log output. + Returns 1 if command not recognised, 0 otherwise. */ +#define VALGRIND_MONITOR_COMMAND(command) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \ + command, 0, 0, 0, 0) + + +/* Change the value of a dynamic command line option. + Note that unknown or not dynamically changeable options + will cause a warning message to be output. */ +#define VALGRIND_CLO_CHANGE(option) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE, \ + option, 0, 0, 0, 0) + + +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + +#endif /* __VALGRIND_H */ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab5e7ce7f1c55a7b8bfff0bda646a4635231871 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__init__.py @@ -0,0 +1,78 @@ +from torch.utils.data.dataloader import ( + _DatasetKind, + DataLoader, + default_collate, + default_convert, + get_worker_info, +) +from torch.utils.data.datapipes._decorator import ( + argument_validation, + functional_datapipe, + guaranteed_datapipes_determinism, + non_deterministic, + runtime_validation, + runtime_validation_disabled, +) +from torch.utils.data.datapipes.datapipe import ( + DataChunk, + DFIterDataPipe, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import ( + ChainDataset, + ConcatDataset, + Dataset, + IterableDataset, + random_split, + StackDataset, + Subset, + TensorDataset, +) +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, + SubsetRandomSampler, + WeightedRandomSampler, +) + + +__all__ = [ + "BatchSampler", + "ChainDataset", + "ConcatDataset", + "DFIterDataPipe", + "DataChunk", + "DataLoader", + "Dataset", + "DistributedSampler", + "IterDataPipe", + "IterableDataset", + "MapDataPipe", + "RandomSampler", + "Sampler", + "SequentialSampler", + "StackDataset", + "Subset", + "SubsetRandomSampler", + "TensorDataset", + "WeightedRandomSampler", + "_DatasetKind", + "argument_validation", + "default_collate", + "default_convert", + "functional_datapipe", + "get_worker_info", + "guaranteed_datapipes_determinism", + "non_deterministic", + "random_split", + "runtime_validation", + "runtime_validation_disabled", +] + +# Please keep this list sorted +if __all__ != sorted(__all__): + raise AssertionError("__all__ is not sorted") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e18fc8d7e045f642a5b87c128bb6081c0e576f80 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11d925d2088bc296713be7f0763a607e15456c44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9488c33a4b050724c3c9cf87c6dbf892e2307d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa0e4ecb35897d83c178225641e316404ae3956 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8cbb9ded9590d2fd40738c8fc35816447d8e0fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..153da05303f9dbcd721525a5b98b6b7a11a423a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b3bcecc0427a246e1640815f8d9684dc5c6d10f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..323e9f05eb157b92aa869a70875da5dea57d1d92 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44111ef697b7188df38711db1add2b8e0de4a293 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py @@ -0,0 +1,53 @@ +r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. + +A lot of multiprocessing is used in data loading, which only supports running +functions defined in global environment (py2 can't serialize static methods). +Therefore, for code tidiness we put these functions into different files in this +folder. +""" + +import atexit +import sys + +# old private location of the ExceptionWrapper that some users rely on: +from torch._utils import ExceptionWrapper + + +IS_WINDOWS = sys.platform == "win32" + + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + + +python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. + +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" + + +try: + import numpy + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + + +def _set_python_exit_flag() -> None: + global python_exit_status + python_exit_status = True + + +atexit.register(_set_python_exit_flag) + + +from . import collate, fetch, pin_memory, signal_handling, worker diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bba82f54918b5560edc702b6a3619a9dae742502 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb5eae39b122825a05b291f36bdff96cc7ea59b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f62fc4dba9d582acd8548c0233b29997a1c687a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5393d9d954f99694820972e500ad4e8de6baff61 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4097d5a09d6e50db162726cf2f4fc515585c253 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b22dec9db3af91b6926868839a302256667db289 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..733e84a9afae622a3d2f3bc7637184e31436d46c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. + +These methods are used to collate samples fetched from dataset into Tensor(s). +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. + +`default_collate` and `default_convert` are exposed to users via 'dataloader.py'. +""" + +import collections +import contextlib +import copy +import re +from collections.abc import Callable + +import torch + + +np_str_obj_array_pattern = re.compile(r"[SaUO]") + + +def default_convert(data): + r""" + Convert each NumPy array element into a :class:`torch.Tensor`. + + If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`. + If the input is not an NumPy array, it is left unchanged. + This is used as the default function for collation when both `batch_sampler` and `batch_size` + are NOT defined in :class:`~torch.utils.data.DataLoader`. + + The general input type to output type mapping is similar to that + of :func:`~torch.utils.data.default_collate`. See the description there for more details. + + Args: + data: a single data point to be converted + + Examples: + >>> # xdoctest: +SKIP + >>> # Example with `int` + >>> default_convert(0) + 0 + >>> # Example with NumPy array + >>> default_convert(np.array([0, 1])) + tensor([0, 1]) + >>> # Example with NamedTuple + >>> Point = namedtuple("Point", ["x", "y"]) + >>> default_convert(Point(0, 0)) + Point(x=0, y=0) + >>> default_convert(Point(np.array(0), np.array(0))) + Point(x=tensor(0), y=tensor(0)) + >>> # Example with List + >>> default_convert([np.array([0, 1]), np.array([2, 3])]) + [tensor([0, 1]), tensor([2, 3])] + """ + elem_type = type(data) + if isinstance(data, torch.Tensor): + return data + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + # array of string classes and object + if ( + elem_type.__name__ == "ndarray" + and np_str_obj_array_pattern.search(data.dtype.str) is not None + ): + return data + return torch.as_tensor(data) + elif isinstance(data, collections.abc.Mapping): + try: + if isinstance(data, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(data) + clone.update({key: default_convert(data[key]) for key in data}) + return clone + else: + return elem_type({key: default_convert(data[key]) for key in data}) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return {key: default_convert(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple + return elem_type(*(default_convert(d) for d in data)) + elif isinstance(data, tuple): + return [default_convert(d) for d in data] # Backwards compatibility. + elif isinstance(data, collections.abc.Sequence) and not isinstance( + data, (str, bytes) + ): + try: + if isinstance(data, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) # type: ignore[arg-type] + for i, d in enumerate(data): + clone[i] = default_convert(d) + return clone + else: + return elem_type([default_convert(d) for d in data]) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [default_convert(d) for d in data] + else: + return data + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}" +) + + +def collate( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + r""" + General collate function that handles collection type of element within each batch. + + The function also opens function registry to deal with specific element types. `default_collate_fn_map` + provides default collate functions for tensors, numpy arrays, numbers and strings. + + Args: + batch: a single batch to be collated + collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function. + If the element type isn't present in this dictionary, + this function will go through each key of the dictionary in the insertion order to + invoke the corresponding collate function if the element type is a subclass of the key. + + Examples: + >>> def collate_tensor_fn(batch, *, collate_fn_map): + ... # Extend this function to handle batch of tensors + ... return torch.stack(batch, 0) + >>> def custom_collate(batch): + ... collate_map = {torch.Tensor: collate_tensor_fn} + ... return collate(batch, collate_fn_map=collate_map) + >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` + >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn}) + + Note: + Each collate function requires a positional argument for batch and a keyword argument + for the dictionary of collate functions as `collate_fn_map`. + """ + elem = batch[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type]( + batch, collate_fn_map=collate_fn_map + ) + + if isinstance(elem, collections.abc.Mapping): + try: + if isinstance(elem, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(elem) + clone.update( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + return clone + else: + return elem_type( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return { + key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + } + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type( + *( + collate(samples, collate_fn_map=collate_fn_map) + for samples in zip(*batch, strict=False) + ) + ) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + # pyrefly: ignore [not-iterable] + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = list( + zip(*batch, strict=False) + ) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] # Backwards compatibility. + else: + try: + if isinstance(elem, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(elem) # type: ignore[arg-type] + for i, samples in enumerate(transposed): + clone[i] = collate(samples, collate_fn_map=collate_fn_map) + return clone + else: + return elem_type( + [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + ) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) + + +def collate_tensor_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + elem = batch[0] + out = None + if elem.is_nested: + raise RuntimeError( + "Batches of nested tensors are not currently supported by the default collate_fn; " + "please provide a custom collate_fn to handle them appropriately." + ) + if elem.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + raise RuntimeError( + "Batches of sparse tensors are not currently supported by the default collate_fn; " + "please provide a custom collate_fn to handle them appropriately." + ) + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem._typed_storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return torch.stack(batch, 0, out=out) + + +def collate_numpy_array_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + elem = batch[0] + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map) + + +def collate_numpy_scalar_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + return torch.as_tensor(batch) + + +def collate_float_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + return torch.tensor(batch, dtype=torch.float64) + + +def collate_int_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + return torch.tensor(batch) + + +def collate_str_fn( + batch, + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + return batch + + +default_collate_fn_map: dict[type | tuple[type, ...], Callable] = { + torch.Tensor: collate_tensor_fn +} +with contextlib.suppress(ImportError): + import numpy as np + + # For both ndarray and memmap (subclass of ndarray) + default_collate_fn_map[np.ndarray] = collate_numpy_array_fn + # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html + # Skip string scalars + default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn +default_collate_fn_map[float] = collate_float_fn +default_collate_fn_map[int] = collate_int_fn +default_collate_fn_map[str] = collate_str_fn +default_collate_fn_map[bytes] = collate_str_fn + + +def default_collate(batch): + r""" + Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size. + + The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a + Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type. + This is used as the default function for collation when + `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`. + + Here is the general input type (based on the type of the element within the batch) to output type mapping: + + * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size) + * NumPy Arrays -> :class:`torch.Tensor` + * `float` -> :class:`torch.Tensor` + * `int` -> :class:`torch.Tensor` + * `str` -> `str` (unchanged) + * `bytes` -> `bytes` (unchanged) + * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]` + * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + + Args: + batch: a single batch to be collated + + Examples: + >>> # xdoctest: +SKIP + >>> # Example with a batch of `int`s: + >>> default_collate([0, 1, 2, 3]) + tensor([0, 1, 2, 3]) + >>> # Example with a batch of `str`s: + >>> default_collate(["a", "b", "c"]) + ['a', 'b', 'c'] + >>> # Example with `Map` inside the batch: + >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}]) + {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} + >>> # Example with `NamedTuple` inside the batch: + >>> Point = namedtuple("Point", ["x", "y"]) + >>> default_collate([Point(0, 0), Point(1, 1)]) + Point(x=tensor([0, 1]), y=tensor([0, 1])) + >>> # Example with `Tuple` inside the batch: + >>> default_collate([(0, 1), (2, 3)]) + [tensor([0, 2]), tensor([1, 3])] + >>> # Example with `List` inside the batch: + >>> default_collate([[0, 1], [2, 3]]) + [tensor([0, 2]), tensor([1, 3])] + >>> # Two options to extend `default_collate` to handle specific type + >>> # Option 1: Write custom collate function and invoke `default_collate` + >>> def custom_collate(batch): + ... elem = batch[0] + ... if isinstance(elem, CustomType): # Some custom condition + ... return ... + ... else: # Fall back to `default_collate` + ... return default_collate(batch) + >>> # Option 2: In-place modify `default_collate_fn_map` + >>> def collate_customtype_fn(batch, *, collate_fn_map=None): + ... return ... + >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) + >>> default_collate(batch) # Handle `CustomType` automatically + """ + return collate(batch, collate_fn_map=default_collate_fn_map) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd0ec5b30731269fc304b5ef2e087d94dc3211 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py @@ -0,0 +1,57 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset. + +This logic is shared in both single- and multi-processing data loading. +""" + +from typing import NoReturn + + +class _BaseDatasetFetcher: + def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None: + self.dataset = dataset + self.auto_collation = auto_collation + self.collate_fn = collate_fn + self.drop_last = drop_last + + def fetch(self, possibly_batched_index) -> NoReturn: + raise NotImplementedError + + +class _IterableDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None: + super().__init__(dataset, auto_collation, collate_fn, drop_last) + self.dataset_iter = iter(dataset) + self.ended = False + + def fetch(self, possibly_batched_index): + if self.ended: + raise StopIteration + + if self.auto_collation: + data = [] + for _ in possibly_batched_index: + try: + data.append(next(self.dataset_iter)) + except StopIteration: + self.ended = True + break + if len(data) == 0 or ( + self.drop_last and len(data) < len(possibly_batched_index) + ): + raise StopIteration + else: + data = next(self.dataset_iter) + return self.collate_fn(data) + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def fetch(self, possibly_batched_index): + if self.auto_collation: + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data = self.dataset.__getitems__(possibly_batched_index) + else: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..a7646ea7677c1f770b413ae18ed055e79e41b189 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import collections +import copy +import queue + +import torch +from torch._utils import ExceptionWrapper + +from . import MP_STATUS_CHECK_INTERVAL + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device) -> None: + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. + torch.set_num_threads(1) + + torch.multiprocessing._set_thread_name("pt_data_pin") + torch.accelerator.set_device_index(device_id) + + def do_one_step() -> None: + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + return + idx, data = r + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): + try: + data = pin_memory(data, device) + except Exception: + data = ExceptionWrapper( + where=f"in pin memory thread for device {device_id}" + ) + r = (idx, data) + while not done_event.is_set(): + try: + out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) + break + except queue.Full: + continue + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + while not done_event.is_set(): + # Make sure that we don't preserve any object from one iteration + # to the next + do_one_step() + + +def pin_memory(data, device=None): + if isinstance(data, torch.Tensor): + return data.pin_memory(device) + elif isinstance(data, (str, bytes)): + return data + elif isinstance(data, collections.abc.Mapping): + try: + if isinstance(data, collections.abc.MutableMapping): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) + clone.update( + {k: pin_memory(sample, device) for k, sample in data.items()} + ) + return clone + else: + return type(data)( + # pyrefly: ignore [bad-argument-count] + {k: pin_memory(sample, device) for k, sample in data.items()} + ) # type: ignore[call-arg] + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return {k: pin_memory(sample, device) for k, sample in data.items()} + elif isinstance(data, tuple): + if hasattr(data, "_fields"): # namedtuple + return type(data)(*(pin_memory(sample, device) for sample in data)) + return type(data)(pin_memory(sample, device) for sample in data) + elif isinstance(data, collections.abc.Sequence): + try: + if isinstance(data, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) # type: ignore[arg-type] + for i, item in enumerate(data): + clone[i] = pin_memory(item, device) + return clone + return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg] + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [pin_memory(sample, device) for sample in data] + elif hasattr(data, "pin_memory"): + return data.pin_memory() + else: + return data diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..abff09bc40819d83420a08e0b90d7ba816f4764f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +r"""Signal handling for multiprocessing data loading. + +NOTE [ Signal handling in multiprocessing data loading ] + +In cases like DataLoader, if a worker process dies due to bus error/segfault +or just hang, the main process will hang waiting for data. This is difficult +to avoid on PyTorch side as it can be caused by limited shm, or other +libraries users call in the workers. In this file and `DataLoader.cpp`, we make +our best effort to provide some error message to users when such unfortunate +events happen. + +When a _BaseDataLoaderIter starts worker processes, their pids are registered in a +defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ] +via `_set_worker_pids`. + +When an error happens in a worker process, the main process received a SIGCHLD, +and Python will eventually call the handler registered below +(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails` +call checks all registered worker pids and raise proper error message to +prevent main process from hanging waiting for data from worker. + +Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, +`_set_worker_signal_handlers` is called to register critical signal handlers +(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error +message to stderr before triggering the default handler. So a message will also +be printed from the worker process when it is killed by such signals. + +See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of +this signal handling design and other mechanism we implement to make our +multiprocessing data loading robust to errors. +""" + +import signal +import threading + +# Some of the following imported functions are not used in this file, but are to +# be used `_utils.signal_handling.XXXXX`. +from torch._C import ( # noqa: F401 + _error_if_any_worker_fails, + _remove_worker_pids, + _set_worker_pids, + _set_worker_signal_handlers, +) + +from . import IS_WINDOWS + + +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler() -> None: + # Windows doesn't support SIGCHLD handler + if IS_WINDOWS: + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined] + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. + previous_handler = None + + def handler(signum, frame) -> None: + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + if not callable(previous_handler): + raise AssertionError("previous_handler is not callable") + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..611aee4766bf451193152d9c6f20055889a2caae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py @@ -0,0 +1,383 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import os +import queue +import random +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING + +import torch +from torch._utils import ExceptionWrapper + +from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling + + +if TYPE_CHECKING: + from torch.utils.data import Dataset + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import BOOL, DWORD, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog: + def __init__(self) -> None: + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess( + SYNCHRONIZE, 0, self.manager_pid + ) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self) -> bool: + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = ( + self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + ) + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self) -> None: + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self) -> bool: + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info: Optional["WorkerInfo"] = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs) -> None: + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val) -> None: + if self.__initialized: + raise RuntimeError( + f"Cannot assign attributes to {self.__class__.__name__} objects" + ) + return super().__setattr__(key, val) + + def __repr__(self) -> str: + items = [f"{k}={getattr(self, k)}" for k in self.__keys] + return f"{self.__class__.__name__}({', '.join(items)})" + + +def get_worker_info() -> WorkerInfo | None: + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. + """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: int | None = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +# TODO: Implement `SeedSequence` like object for `torch.random` +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. + for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, +) -> None: + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.multiprocessing._set_thread_name("pt_data_worker") + + torch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + torch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + from torch.utils.data import IterDataPipe + from torch.utils.data.graph_settings import apply_random_seed + + shared_rng = torch.Generator() + if isinstance(dataset, IterDataPipe): + if shared_seed is None: + raise AssertionError( + "shared_seed must be provided for IterDataPipe workers" + ) + shared_rng.manual_seed(shared_seed) + dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset + ) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + except Exception: + init_exception = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. + iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + if isinstance(dataset, IterDataPipe): + if r.seed is None: + raise AssertionError( + "resume iteration seed is None for IterDataPipe" + ) + shared_rng.manual_seed(r.seed) + dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + continue + elif r is None: + # Received the final signal + if not done_event.is_set() and not iteration_end: + raise AssertionError( + "Received final signal but neither done_event nor iteration_end is set" + ) + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + data: _IterableDatasetStopIteration | ExceptionWrapper + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) # type: ignore[possibly-undefined] + except Exception as e: + if ( + isinstance(e, StopIteration) + and dataset_kind == _DatasetKind.Iterable + ): + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..5b928aea69fa7a7033a82021c5f41e053ff962fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from typing_extensions import deprecated as _deprecated + + +@_deprecated( + "Usage of `backward_compatibility.worker_init_fn` is deprecated " + "as `DataLoader` automatically applies sharding in every worker", + category=FutureWarning, +) +def worker_init_fn(worker_id) -> None: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataloader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2cd710faf6e7bc6df41e86169253f85357c83f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataloader.py @@ -0,0 +1,1707 @@ +# mypy: allow-untyped-defs +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + +from __future__ import annotations + +import contextlib +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from collections.abc import Callable +from typing import Any, Generic, NoReturn, TYPE_CHECKING, TypeVar +from typing_extensions import Self + +import torch +import torch.distributed as dist +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import _utils +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, + _MapDataPipeSerializationWrapper, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + +__all__ = [ + "DataLoader", + "get_worker_info", + "default_collate", + "default_convert", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[list[_T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + + Used as sampler for :class:`~torch.utils.data.IterableDataset`. + """ + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id) -> None: + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + if info is None: + raise AssertionError("Worker info is None in sharding worker init function") + total_workers = info.num_workers + datapipe = info.dataset + if not isinstance(datapipe, (IterDataPipe, MapDataPipe)): + raise AssertionError( + "datapipe must be an instance of IterDataPipe or MapDataPipe" + ) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + datapipe, total_workers, global_worker_id + ) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class DataLoader(Generic[_T_co]): + r""" + Data loader combines a dataset and a sampler, and provides an iterable over the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If + ``None``, the default + `multiprocessing context `_ # noqa: D401 + of your operating system will + be used. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + ``base_seed`` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise, if value of ``num_workers > 0`` default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shut down + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator` + will be used as the device if ``pin_memory=True``. + in_order (bool, optional): If ``False``, the data loader will not enforce that batches + are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + + .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data + distribution being fed to the trainer in cases with imbalanced data. + """ + + dataset: Dataset[_T_co] + batch_size: int | None + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Sampler | Iterable + pin_memory_device: str + prefetch_factor: int | None + _iterator: _BaseDataLoaderIter | None + __initialized = False + + def __init__( + self, + dataset: Dataset[_T_co], + batch_size: int | None = 1, + shuffle: bool | None = None, + sampler: Sampler | Iterable | None = None, + batch_sampler: Sampler[list] | Iterable[list] | None = None, + num_workers: int = 0, + collate_fn: _collate_fn_t | None = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: _worker_init_fn_t | None = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: int | None = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + in_order: bool = True, + ) -> None: + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + self.in_order = in_order + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + f"batch_sampler option, but got batch_sampler={batch_sampler}" + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> _BaseDataLoaderIter: + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context) -> None: + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = torch.multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + "multiprocessing_context option " + f"should specify a valid start method in {valid_start_methods!r}, but got " + f"multiprocessing_context={multiprocessing_context!r}" + ) + multiprocessing_context = torch.multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise TypeError( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + f"multiprocessing_context={multiprocessing_context}" + ) + else: + raise ValueError( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + f"num_workers={self.num_workers}" + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val) -> None: + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + f"{attr} attribute should not be set after {self.__class__.__name__} is initialized" + ) + + super().__setattr__(attr, val) + + def __iter__(self) -> _BaseDataLoaderIter: + # When using a single worker the returned iterator should be + # created every time to avoid resetting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] + if ( + self.batch_size is not None + ): # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + def check_worker_number_rationality(self) -> None: + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't respect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + suggested_max_worker_msg = ( + ( + ( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create." + ).format( + num_worker_suggest, + ( + "" + if cpuset_checked + else " (`cpuset` is not taken into account)" + ), + ) + ) + if num_worker_suggest is not None + else ( + "DataLoader is not able to compute a suggested max number of worker in current system." + ) + ) + + warn_msg = ( + f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary." + ) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satisfy mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ), + stacklevel=2, + ) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ), + stacklevel=2, + ) + + +class _BaseDataLoaderIter: + def __init__(self, loader: DataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + + if loader.pin_memory and loader.pin_memory_device: + warnings.warn( + "pin_memory_device is deprecated, the current accelerator will be used as the device," + f"ignore pin_memory_device='{loader.pin_memory_device}'.", + stacklevel=2, + ) + if loader.pin_memory and not torch.accelerator.is_available(): + warn_msg = ( + "'pin_memory' argument is set as true but no accelerator is found, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg, stacklevel=2) + + # Enabling pin_memory in _BaseDataLoaderIter to support identical + # behavior in forked implementations using _BaseDataLoaderIter. + self._pin_memory = loader.pin_memory and torch.accelerator.is_available() + + # Set pin memory device based on the current accelerator. + self._pin_memory_device = ( + acc.type + if self._pin_memory + and (acc := torch.accelerator.current_accelerator()) is not None + else None + ) + + # Currently, pin_memory would raise error on the MPS backend (see + # https://github.com/pytorch/pytorch/issues/86060), so forcibly + # disable pin_memory on MPS. Remove this restriction once pinned + # memory allocation for MPS is fixed. + if self._pin_memory_device == "mps": + self._pin_memory = False + warn_msg = ( + "'pin_memory' argument is set as true but not supported on MPS now, " + "device pinned memory won't be used." + ) + warnings.warn(warn_msg, stacklevel=2) + + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = ( + torch.empty((), dtype=torch.int64) + .random_(generator=loader.generator) + .item() + ) + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" + + def __iter__(self) -> Self: + return self + + def _reset(self, loader, first_iter=False) -> None: + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self) -> NoReturn: + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + # TODO(https://github.com/pytorch/pytorch/issues/76750) + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}" + f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. " + ) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg, stacklevel=2) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader) -> None: + super().__init__(loader) + if self._timeout != 0: + raise AssertionError("_SingleProcessDataLoaderIter requires timeout == 0") + if self._num_workers != 0: + raise AssertionError( + "_SingleProcessDataLoaderIter requires num_workers == 0" + ) + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler.""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may already be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Similarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. + # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separate events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `worker_result_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. + # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader) -> None: + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + self._in_order = loader.in_order + + if self._num_workers <= 0: + raise AssertionError( + "num_workers must be greater than 0 for MultiProcessingDataLoaderIter" + ) + if self._prefetch_factor <= 0: + raise AssertionError( + "prefetch_factor must be greater than 0 for MultiProcessingDataLoaderIter" + ) + + if loader.multiprocessing_context is None: + multiprocessing_context = torch.multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + from pickle import PicklingError + + try: + w.start() + except (TypeError, AttributeError, PicklingError): + warnings.warn( + "Got pickle error when attempting to start a worker Process. " + "This might be because the worker Process arguments are not picklable. " + "Python 3.14+ changed the multiprocessing start method in non-Mac POSIX platforms " + "to 'forkserver', which requires the worker Process arguments to be picklable. " + "You can also try multiprocessing.set_start_method('fork').", + stacklevel=2, + ) + raise + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + current_device_id = torch.accelerator.current_device_index() + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device_id, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue # type: ignore[assignment] + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids( + id(self), + tuple(w.pid for w in self._workers), # type: ignore[misc] + ) + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False) -> None: + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. + self._workers_status = [True for i in range(self._num_workers)] + # A list of integers representing how many tasks are outstanding for each worker + # Incremented when a task is dispatched to the worker + # Decremented when that data has been given to the main thread + # Each worker should have at most self._prefetch_factor tasks outstanding + self._workers_num_tasks = [0 for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + if return_data is not None: + raise AssertionError( + "Expected return_data to be None when resuming iteration" + ) + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError( + f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" + ) from e + if isinstance(e, queue.Empty): + return (False, None) + + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. + # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + with contextlib.ExitStack() as stack: + for _ in range(fds_limit_margin): + stack.enter_context( + tempfile.NamedTemporaryFile() # pyrefly: ignore [bad-argument-type] + ) + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. + # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. + if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + f"DataLoader timed out after {self._timeout} seconds" + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info.get(self._rcvd_idx, None) + if info: + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + worker_id, data = self._task_info.pop(self._rcvd_idx) + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + if self._shutdown or self._tasks_outstanding <= 0: + raise AssertionError( + "Invalid iterator state: shutdown or no outstanding tasks when fetching next data" + ) + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + if not self._in_order: + # don't store it for later, process now + # delete from self._task_info immediately + # this keeps the object size manageable + worker_id = self._task_info.pop(idx)[0] + return self._process_data(data, worker_id) + # store out-of-order samples + self._task_info[idx] += (data,) + else: + worker_id = self._task_info.pop(idx)[0] + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + def _try_put_index(self) -> None: + max_tasks = self._prefetch_factor * self._num_workers + if self._tasks_outstanding >= max_tasks: + raise AssertionError( + "Number of outstanding tasks exceeded maximum allowed tasks" + ) + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + if self._in_order: + break + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum( + self._workers_status + ): + # when self._in_order is False, distribute work to a worker if it has capacity + # _workers_status is updated only in this thread, so the sum is guaranteed > 0 + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined] + self._task_info[self._send_idx] = (worker_queue_idx,) + self._workers_num_tasks[worker_queue_idx] += 1 + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data, worker_idx): + self._workers_num_tasks[worker_idx] -= 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False) -> None: + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. + + if ( + not self._workers_status[worker_id] + and not self._persistent_workers + and not shutdown + ): + raise AssertionError( + "Worker status inconsistent when marking worker as unavailable" + ) + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. + q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joining is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + + self._workers_status[worker_id] = False + + if self._workers_done_event.is_set() != shutdown: + raise AssertionError( + "_workers_done_event state does not match shutdown flag" + ) + + def _shutdown_workers(self) -> None: + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, "_pin_memory_thread"): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + # If we are using workers_status with persistent_workers + # we have to shut it down because the worker is paused + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. + # + # FIXME: Unfortunately, for Windows, we are missing a worker + # error detection mechanism here in this function, as it + # doesn't provide a SIGCHLD handler. + if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() + + # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` + @staticmethod + def _clean_up_worker(w) -> None: + try: + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + finally: + if w.is_alive(): + w.terminate() + + def __del__(self) -> None: + self._shutdown_workers() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac93de335b2d7379246de9cee658dd9eafe1d303 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py @@ -0,0 +1 @@ +from torch.utils.data.datapipes import dataframe as dataframe, iter as iter, map as map diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a85da6b10b6630aa41d1fc457dfdb412ce08cc8f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6293df71f420786b4c3bb84e0ec7a8a4d819e895 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92684b7caffe1539c633b8699ac8dce0237f1da1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..683b1fce7ab063267ea4128061f80602b736ca81 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b89fc99d7fb0b064c482c0c3559a10ef656af1b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c7f598916ac1722f43482ccf1a88223ae0c2fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..0289668c03abcfc0a8e37bc9ff62365fea3dd1cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py @@ -0,0 +1,213 @@ +# mypy: allow-untyped-defs +import inspect +from collections.abc import Callable +from functools import wraps +from typing import Any, get_type_hints + +from torch.utils.data.datapipes._typing import _DataPipeMeta +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +###################################################### +# Functional API +###################################################### +class functional_datapipe: + name: str + + def __init__(self, name: str, enable_df_api_tracing=False) -> None: + """ + Define a functional datapipe. + + Args: + enable_df_api_tracing - if set, any returned DataPipe would accept + DataFrames API in tracing mode. + """ + self.name = name + self.enable_df_api_tracing = enable_df_api_tracing + + def __call__(self, cls): + if issubclass(cls, IterDataPipe): + if isinstance(cls, type): # type: ignore[arg-type] + if not isinstance(cls, _DataPipeMeta): + raise TypeError( + "`functional_datapipe` can only decorate IterDataPipe" + ) + # with non_deterministic decorator + else: + if not isinstance(cls, non_deterministic) and not ( + hasattr(cls, "__self__") + and isinstance(cls.__self__, non_deterministic) + ): + raise TypeError( + "`functional_datapipe` can only decorate IterDataPipe" + ) + IterDataPipe.register_datapipe_as_function( + self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing + ) + elif issubclass(cls, MapDataPipe): + MapDataPipe.register_datapipe_as_function(self.name, cls) + + return cls + + +###################################################### +# Determinism +###################################################### +_determinism: bool = False + + +class guaranteed_datapipes_determinism: + prev: bool + + def __init__(self) -> None: + global _determinism + self.prev = _determinism + _determinism = True + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _determinism + _determinism = self.prev + + +class non_deterministic: + cls: type[IterDataPipe] | None = None + # TODO: Lambda for picking + deterministic_fn: Callable[..., bool] + + def __init__(self, arg: type[IterDataPipe] | Callable[..., bool]) -> None: + # 1. Decorator doesn't have any argument + if isinstance(arg, type): # type: ignore[arg-type] + if not issubclass(arg, IterDataPipe): # type: ignore[arg-type] + raise TypeError( + "Only `IterDataPipe` can be decorated with `non_deterministic`" + f", but {arg.__name__} is found" + ) + self.cls = arg # type: ignore[assignment] + # 2. Decorator has an argument of a function + # This class should behave differently given different inputs. Use this + # function to verify the determinism for each instance. + # When the function returns True, the instance is non-deterministic. Otherwise, + # the instance is a deterministic DataPipe. + elif isinstance(arg, Callable): # type:ignore[arg-type] + self.deterministic_fn = arg + else: + raise TypeError(f"{arg} can not be decorated by non_deterministic") + + def __call__(self, *args, **kwargs): + global _determinism + # Decorate IterDataPipe + if self.cls is not None: + if _determinism: + raise TypeError( + f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. " + "You can turn off determinism for this DataPipe if that is acceptable " + "for your application" + ) + return self.cls(*args, **kwargs) # type: ignore[call-arg] + + # Decorate with a functional argument + if not ( + isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] + ): + raise TypeError( + f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" + ) + self.cls = args[0] + return self.deterministic_wrapper_fn + + def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe: + res = self.deterministic_fn(*args, **kwargs) + if not isinstance(res, bool): + raise TypeError( + "deterministic_fn of `non_deterministic` decorator is required " + f"to return a boolean value, but {type(res)} is found" + ) + global _determinism + if _determinism and res: + raise TypeError( + f"{self.cls.__name__} is non-deterministic with the inputs, but you set " # type: ignore[union-attr] + "'guaranteed_datapipes_determinism'. You can turn off determinism " + "for this DataPipe if that is acceptable for your application" + ) + return self.cls(*args, **kwargs) # type: ignore[call-arg, misc] + + +###################################################### +# Type validation +###################################################### +# Validate each argument of DataPipe with hint as a subtype of the hint. +def argument_validation(f): + signature = inspect.signature(f) + hints = get_type_hints(f) + + @wraps(f) + def wrapper(*args, **kwargs): + bound = signature.bind(*args, **kwargs) + for argument_name, value in bound.arguments.items(): + if argument_name in hints and isinstance( + hints[argument_name], _DataPipeMeta + ): + hint = hints[argument_name] + if not isinstance(value, IterDataPipe): + raise TypeError( + f"Expected argument '{argument_name}' as a IterDataPipe, but found {type(value)}" + ) + if not value.type.issubtype(hint.type): + raise TypeError( + f"Expected type of argument '{argument_name}' as a subtype of " + f"hint {hint.type}, but found {value.type}" + ) + + return f(*args, **kwargs) + + return wrapper + + +# Default value is True +_runtime_validation_enabled: bool = True + + +class runtime_validation_disabled: + prev: bool + + def __init__(self) -> None: + global _runtime_validation_enabled + self.prev = _runtime_validation_enabled + _runtime_validation_enabled = False + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _runtime_validation_enabled + _runtime_validation_enabled = self.prev + + +# Runtime checking +# Validate output data is subtype of return hint +def runtime_validation(f): + # TODO: + # Can be extended to validate '__getitem__' and nonblocking + if f.__name__ != "__iter__": + raise TypeError( + f"Can not decorate function {f.__name__} with 'runtime_validation'" + ) + + @wraps(f) + def wrapper(self): + global _runtime_validation_enabled + if not _runtime_validation_enabled: + yield from f(self) + else: + it = f(self) + for d in it: + if not self.type.issubtype_of_instance(d): + raise RuntimeError( + f"Expected an instance as subtype of {self.type}, but found {d}({type(d)})" + ) + yield d + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..26836168047497000de0003b0489b19e832015bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +import functools +import inspect +from enum import Enum + +import torch + + +class _SnapshotState(Enum): + r""" + These are the snapshotting-related states that IterDataPipes can be in. + + `NotStarted` - allows you to restore a snapshot and create an iterator with reset + `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe + `Iterating` - can restore, will reset if you create a new iterator + """ + + NotStarted = 0 + Restored = 1 + Iterating = 2 + + +def _simplify_obj_name(obj) -> str: + """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.""" + if inspect.isfunction(obj): + return obj.__name__ + else: + return repr(obj) + + +def _strip_datapipe_from_name(name: str) -> str: + return name.replace("IterDataPipe", "").replace("MapDataPipe", "") + + +def _generate_input_args_string(obj): + """Generate a string for the input arguments of an object.""" + signature = inspect.signature(obj.__class__) + input_param_names = set(signature.parameters.keys()) + result = [] + for name, value in inspect.getmembers(obj): + if name in input_param_names: + result.append((name, _simplify_obj_name(value))) + return ", ".join([f"{name}={value}" for name, value in result]) + + +def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False): + output_string = ( + f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})" + ) + if simplify_dp_name: + output_string = _strip_datapipe_from_name(output_string) + return output_string + + +def _gen_invalid_iterdatapipe_msg(datapipe) -> str: + return ( + "This iterator has been invalidated because another iterator has been created " + f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n" + "This may be caused multiple references to the same IterDataPipe. We recommend " + "using `.fork()` if that is necessary." + ) + + +_feedback_msg = ( + "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free " + "to comment on this issue: https://github.com/pytorch/data/issues/45." +) + + +def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None: + r""" + Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception. + + In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well. + """ + if next_method_exists: + # This is the case where `IterDataPipe` has both `__iter__` and `__next__`. + # The `_valid_iterator_id` should either be never set (`None`), or set by at most one + # iterator (`0`). Otherwise, it means there are multiple iterators. + if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0: + extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method" + raise RuntimeError( + _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg + ) + elif ( + hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True + ): + if hasattr(datapipe, "_check_valid_iterator_id"): + if not datapipe._check_valid_iterator_id(iterator_id): + raise RuntimeError( + "This iterator has been invalidated, because a new iterator has been created " + f"from one of the ChildDataPipes of " + f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + + _feedback_msg + ) + else: + raise RuntimeError( + "ChildDataPipe must have method `_check_valid_iterator_id`." + ) + elif datapipe._valid_iterator_id != iterator_id: + raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg) + + +def _set_datapipe_valid_iterator_id(datapipe): + """Given a DataPipe, updates its valid iterator ID and reset the DataPipe.""" + if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True: + if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"): + datapipe._set_main_datapipe_valid_iterator_id() # reset() is called within this method when appropriate + else: + raise RuntimeError( + "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`." + ) + else: + if datapipe._valid_iterator_id is None: + datapipe._valid_iterator_id = 0 + else: + datapipe._valid_iterator_id += 1 + datapipe.reset() + return datapipe._valid_iterator_id + + +def hook_iterator(namespace) -> None: + r""" + Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. + + This is done for the purpose of profiling and checking if an iterator is still valid. + """ + + def profiler_record_fn_context(datapipe): + if not hasattr(datapipe, "_profile_name"): + datapipe._profile_name = _generate_iterdatapipe_msg( + datapipe, simplify_dp_name=True + ) + return torch.autograd.profiler.record_function(datapipe._profile_name) + + class IteratorDecorator: + r""" + Wrap the iterator and modifying its `__next__` method. + + This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function. + Those `__iter__` method commonly returns `self` but not necessarily. + """ + + def __init__(self, iterator, datapipe, iterator_id, has_next_method) -> None: + self.iterator = iterator + self.datapipe = datapipe + self.iterator_id = iterator_id + self._profiler_enabled = torch.autograd._profiler_enabled() + # Check if `__iter__` returns `self` and `DataPipe` has `__next__` + self.self_and_has_next_method = ( + self.iterator is self.datapipe and has_next_method + ) + + def __iter__(self): + return self + + def _get_next(self): + """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.""" + _check_iterator_valid(self.datapipe, self.iterator_id) + result = next(self.iterator) + if not self.self_and_has_next_method: + self.datapipe._number_of_samples_yielded += 1 + return result + + def __next__(self): + # TODO: Add try-except to in-place reduce traceback from the Exception + # See: https://github.com/pytorch/data/issues/284 + if self._profiler_enabled: + with profiler_record_fn_context(self.datapipe): + return self._get_next() + else: # Decided against using `contextlib.nullcontext` for performance reasons + return self._get_next() + + def __getattr__(self, name): + return getattr(self.iterator, name) + + func = namespace["__iter__"] + + # ``__iter__`` of IterDataPipe is a generator function + if inspect.isgeneratorfunction(func): + + @functools.wraps(func) + def wrap_generator(*args, **kwargs): + gen = func(*args, **kwargs) + datapipe = args[0] + if datapipe._fast_forward_iterator: + it = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + datapipe._snapshot_state = _SnapshotState.Iterating + while True: + try: + yield next(it) + except StopIteration: + return + iterator_id = _set_datapipe_valid_iterator_id( + datapipe + ) # This ID is tied to each created iterator + _profiler_enabled = torch.autograd._profiler_enabled() + try: + if _profiler_enabled: + with profiler_record_fn_context(datapipe): + response = gen.send(None) + else: + response = gen.send(None) + + while True: + datapipe._number_of_samples_yielded += 1 + request = yield response + # Pass through here every time `__next__` is called + if _profiler_enabled: + with profiler_record_fn_context(datapipe): + _check_iterator_valid(datapipe, iterator_id) + response = gen.send(request) + else: # Decided against using `contextlib.nullcontext` for performance reasons + _check_iterator_valid(datapipe, iterator_id) + response = gen.send(request) + except StopIteration: + return + except Exception as e: + # TODO: Simplify the traceback message to skip over `response = gen.send(None)` + # Part of https://github.com/pytorch/data/issues/284 + datapipe = args[0] + msg = "thrown by __iter__ of" + single_iterator_msg = "single iterator per IterDataPipe constraint" + if hasattr(e.args, "__len__"): + full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})" + if len(e.args) == 0 or not isinstance( + e.args[0], str + ): # If an exception message doesn't exist + e.args = (f"\nThis exception is {full_msg}",) + elif msg not in e.args[0] and single_iterator_msg not in e.args[0]: + e.args = ( + e.args[0] + f"\nThis exception is {full_msg}", + ) + e.args[1:] + raise + + namespace["__iter__"] = wrap_generator + else: # ``__iter__`` of IterDataPipe is NOT a generator function + # IterDataPipe is an iterator with both ``__iter__`` and ``__next__`` + # And ``__iter__`` may or may not return `self` + if "__next__" in namespace: # If `__next__` exists, put a wrapper around it + next_func = namespace["__next__"] + + @functools.wraps(next_func) + def wrap_next(*args, **kwargs): + datapipe = args[0] + if torch.autograd._profiler_enabled(): + with profiler_record_fn_context(datapipe): + result = next_func(*args, **kwargs) + else: + result = next_func(*args, **kwargs) + datapipe._number_of_samples_yielded += 1 + return result + + namespace["__next__"] = wrap_next + + # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but + # the user will be violating the iterator protocol. Potential issue: + # 1. Valid iterator ID may not update or checked properly + # 2. The number of samples yielded will be miscounted + + # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators + @functools.wraps(func) + def wrap_iter(*args, **kwargs): + iter_ret = func(*args, **kwargs) + datapipe = args[0] + datapipe._snapshot_state = _SnapshotState.Iterating + if datapipe._fast_forward_iterator: + iter_ret = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + return iter_ret + iterator_id = _set_datapipe_valid_iterator_id( + datapipe + ) # This ID is tied to each created iterator + return IteratorDecorator( + iter_ret, datapipe, iterator_id, "__next__" in namespace + ) + + namespace["__iter__"] = wrap_iter diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..e198aa16caa66105c0b2009ed8da1e655effe151 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py @@ -0,0 +1,484 @@ +# mypy: allow-untyped-defs +# Taking reference from official Python typing +# https://github.com/python/cpython/blob/master/Lib/typing.py + +import collections +import functools +import numbers +import sys + +# Please check [Note: TypeMeta and TypeAlias] +# In case of metaclass conflict due to ABCMeta or _ProtocolMeta +# For Python 3.9, only Protocol in typing uses metaclass +from abc import ABCMeta +from collections.abc import Iterator + +# TODO: Use TypeAlias when Python 3.6 is deprecated +from typing import ( # type: ignore[attr-defined] + _eval_type, + _GenericAlias, + _tp_cache, + _type_check, + _type_repr, + Any, + ForwardRef, + Generic, + get_type_hints, + TypeVar, + Union, +) + +from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator + + +class GenericMeta(ABCMeta): # type: ignore[no-redef] + pass + + +class Integer(numbers.Integral): + pass + + +class Boolean(numbers.Integral): + pass + + +# Python 'type' object is not subscriptable +# Tuple[int, List, dict] -> valid +# tuple[int, list, dict] -> invalid +# Map Python 'type' to abstract base class +TYPE2ABC = { + bool: Boolean, + int: Integer, + float: numbers.Real, + complex: numbers.Complex, + dict: dict, + list: list, + set: set, + tuple: tuple, + None: type(None), +} + + +def issubtype(left, right, recursive=True): + r""" + Check if the left-side type is a subtype of the right-side type. + + If any of type is a composite type like `Union` and `TypeVar` with + bounds, it would be expanded into a list of types and check all + of left-side types are subtypes of either one from right-side types. + """ + left = TYPE2ABC.get(left, left) + right = TYPE2ABC.get(right, right) + + if right is Any or left == right: + return True + + if isinstance(right, _GenericAlias): + if getattr(right, "__origin__", None) is Generic: + return True + + if right is type(None): + return False + + # Right-side type + constraints = _decompose_type(right) + + if len(constraints) == 0 or Any in constraints: + return True + + if left is Any: + return False + + # Left-side type + variants = _decompose_type(left) + + # all() will return True for empty variants + if len(variants) == 0: + return False + + return all( + _issubtype_with_constraints(variant, constraints, recursive) + for variant in variants + ) + + +def _decompose_type(t, to_list=True): + if isinstance(t, TypeVar): + if t.__bound__ is not None: + ts = [t.__bound__] + else: + # For T_co, __constraints__ is () + ts = list(t.__constraints__) + elif hasattr(t, "__origin__") and t.__origin__ == Union: + ts = t.__args__ + else: + if not to_list: + return None + ts = [t] + # Ignored: Generator has incompatible item type "object"; expected "Type[Any]" + ts = [TYPE2ABC.get(_t, _t) for _t in ts] # type: ignore[misc] + return ts + + +def _issubtype_with_constraints(variant, constraints, recursive=True): + r""" + Check if the variant is a subtype of either one from constraints. + + For composite types like `Union` and `TypeVar` with bounds, they + would be expanded for testing. + """ + if variant in constraints: + return True + + # [Note: Subtype for Union and TypeVar] + # Python typing is able to flatten Union[Union[...]] or Union[TypeVar]. + # But it couldn't flatten the following scenarios: + # - Union[int, TypeVar[Union[...]]] + # - TypeVar[TypeVar[...]] + # So, variant and each constraint may be a TypeVar or a Union. + # In these cases, all of inner types from the variant are required to be + # extracted and verified as a subtype of any constraint. And, all of + # inner types from any constraint being a TypeVar or a Union are + # also required to be extracted and verified if the variant belongs to + # any of them. + + # Variant + vs = _decompose_type(variant, to_list=False) + + # Variant is TypeVar or Union + if vs is not None: + return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs) + + # Variant is not TypeVar or Union + if hasattr(variant, "__origin__") and variant.__origin__ is not None: + v_origin = variant.__origin__ + # In Python-3.9 typing library untyped generics do not have args + v_args = getattr(variant, "__args__", None) + else: + v_origin = variant + v_args = None + + # Constraints + for constraint in constraints: + cs = _decompose_type(constraint, to_list=False) + + # Constraint is TypeVar or Union + if cs is not None: + if _issubtype_with_constraints(variant, cs, recursive): + return True + # Constraint is not TypeVar or Union + else: + # __origin__ can be None for plain list, tuple, ... in Python 3.6 + if hasattr(constraint, "__origin__") and constraint.__origin__ is not None: + c_origin = constraint.__origin__ + if v_origin == c_origin: + if not recursive: + return True + # In Python-3.9 typing library untyped generics do not have args + c_args = getattr(constraint, "__args__", None) + if c_args is None or len(c_args) == 0: + return True + if ( + v_args is not None + and len(v_args) == len(c_args) + and all( + issubtype(v_arg, c_arg) + for v_arg, c_arg in zip(v_args, c_args, strict=True) + ) + ): + return True + # Tuple[int] -> Tuple + else: + if v_origin == constraint: + return True + + return False + + +def issubinstance(data, data_type): + if not issubtype(type(data), data_type, recursive=False): + return False + + # In Python-3.9 typing library __args__ attribute is not defined for untyped generics + dt_args = getattr(data_type, "__args__", None) + if isinstance(data, tuple): + if dt_args is None or len(dt_args) == 0: + return True + if len(dt_args) != len(data): + return False + return all(issubinstance(d, t) for d, t in zip(data, dt_args, strict=True)) + elif isinstance(data, (list, set)): + if dt_args is None or len(dt_args) == 0: + return True + t = dt_args[0] + return all(issubinstance(d, t) for d in data) + elif isinstance(data, dict): + if dt_args is None or len(dt_args) == 0: + return True + kt, vt = dt_args + return all( + issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items() + ) + + return True + + +# [Note: TypeMeta and TypeAlias] +# In order to keep compatibility for Python 3.6, use Meta for the typing. +# TODO: When PyTorch drops the support for Python 3.6, it can be converted +# into the Alias system and using `__class_getitem__` for DataPipe. The +# typing system will gain benefit of performance and resolving metaclass +# conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/ + + +class _DataPipeType: + r"""Save type annotation in `param`.""" + + def __init__(self, param) -> None: + self.param = param + + def __repr__(self) -> str: + return _type_repr(self.param) + + def __eq__(self, other): + if isinstance(other, _DataPipeType): + return self.param == other.param + return NotImplemented + + def __hash__(self): + return hash(self.param) + + def issubtype(self, other): + if isinstance(other.param, _GenericAlias): + if getattr(other.param, "__origin__", None) is Generic: + return True + if isinstance(other, _DataPipeType): + return issubtype(self.param, other.param) + if isinstance(other, type): + return issubtype(self.param, other) + raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}") + + def issubtype_of_instance(self, other): + return issubinstance(other, self.param) + + +# Default type for DataPipe without annotation +_T_co = TypeVar("_T_co", covariant=True) +# pyrefly: ignore [invalid-annotation] +_DEFAULT_TYPE = _DataPipeType(Generic[_T_co]) + + +class _DataPipeMeta(GenericMeta): + r""" + Metaclass for `DataPipe`. + + Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`. + + Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`. + """ + + type: _DataPipeType + + def __new__(cls, name, bases, namespace, **kwargs): + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now. + # pyrefly: ignore [no-access] + cls.__origin__ = None + if "type" in namespace: + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + namespace["__type_class__"] = False + # For plain derived class without annotation + for base in bases: + if isinstance(base, _DataPipeMeta): + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + namespace.update( + {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass} + ) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + def __init__(self, name, bases, namespace, **kwargs) -> None: + super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload] + + # TODO: Fix isinstance bug + @_tp_cache + def _getitem_(self, params): + if params is None: + raise TypeError(f"{self.__name__}[t]: t can not be None") + if isinstance(params, str): + params = ForwardRef(params) + if not isinstance(params, tuple): + params = (params,) + + msg = f"{self.__name__}[t]: t must be a type" + params = tuple(_type_check(p, msg) for p in params) + + if isinstance(self.type.param, _GenericAlias): + orig = getattr(self.type.param, "__origin__", None) + if isinstance(orig, type) and orig is not Generic: + p = self.type.param[params] # type: ignore[index] + t = _DataPipeType(p) + l = len(str(self.type)) + 2 + name = self.__name__[:-l] + name = name + "[" + str(t) + "]" + bases = (self,) + self.__bases__ + return self.__class__( + name, + bases, + { + "__init_subclass__": _dp_init_subclass, + "type": t, + "__type_class__": True, + }, + ) + + if len(params) > 1: + raise TypeError( + f"Too many parameters for {self} actual {len(params)}, expected 1" + ) + + t = _DataPipeType(params[0]) + + if not t.issubtype(self.type): + raise TypeError( + f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]" + ) + + # Types are equal, fast path for inheritance + if self.type == t: + return self + + name = self.__name__ + "[" + str(t) + "]" + bases = (self,) + self.__bases__ + + return self.__class__( + name, + bases, + {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t}, + ) + + # TODO: Fix isinstance bug + def _eq_(self, other): + if not isinstance(other, _DataPipeMeta): + return NotImplemented + if self.__origin__ is None or other.__origin__ is None: # type: ignore[has-type] + return self is other + return ( + self.__origin__ == other.__origin__ # type: ignore[has-type] + and self.type == other.type + ) + + # TODO: Fix isinstance bug + def _hash_(self): + return hash((self.__name__, self.type)) + + +class _IterDataPipeMeta(_DataPipeMeta): + r""" + Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`. + + Add various functions for behaviors specific to `IterDataPipe`. + """ + + def __new__(cls, name, bases, namespace, **kwargs): + if "reset" in namespace: + reset_func = namespace["reset"] + + @functools.wraps(reset_func) + def conditional_reset(*args, **kwargs) -> None: + r""" + Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`. + + This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call. + """ + datapipe = args[0] + if datapipe._snapshot_state in ( + _SnapshotState.Iterating, + _SnapshotState.NotStarted, + ): + # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have + # already begun iterating. + datapipe._number_of_samples_yielded = 0 + datapipe._fast_forward_iterator = None + reset_func(*args, **kwargs) + datapipe._snapshot_state = _SnapshotState.Iterating + + namespace["reset"] = conditional_reset + + if "__iter__" in namespace: + hook_iterator(namespace) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + +def _dp_init_subclass(sub_cls, *args, **kwargs) -> None: + # Add function for datapipe instance to reinforce the type + sub_cls.reinforce_type = reinforce_type + + # TODO: + # - add global switch for type checking at compile-time + + # Ignore internal type class + if getattr(sub_cls, "__type_class__", False): + return + + # Check if the string type is valid + if isinstance(sub_cls.type.param, ForwardRef): + base_globals = sys.modules[sub_cls.__module__].__dict__ + try: + param = _eval_type(sub_cls.type.param, base_globals, locals()) + sub_cls.type.param = param + except TypeError as e: + raise TypeError( + f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing" + ) from e + + if "__iter__" in sub_cls.__dict__: + iter_fn = sub_cls.__dict__["__iter__"] + hints = get_type_hints(iter_fn) + if "return" in hints: + return_hint = hints["return"] + # Plain Return Hint for Python 3.6 + if return_hint == Iterator: + return + if not ( + hasattr(return_hint, "__origin__") + and ( + return_hint.__origin__ == Iterator + or return_hint.__origin__ == collections.abc.Iterator + ) + ): + raise TypeError( + "Expected 'Iterator' as the return annotation for `__iter__` of {}" + ", but found {}".format( + sub_cls.__name__, _type_repr(hints["return"]) + ) + ) + data_type = return_hint.__args__[0] + if not issubtype(data_type, sub_cls.type.param): + raise TypeError( + f"Expected return type of '__iter__' as a subtype of {sub_cls.type}," + f" but found {_type_repr(data_type)} for {sub_cls.__name__}" + ) + + +def reinforce_type(self, expected_type): + r""" + Reinforce the type for DataPipe instance. + + And the 'expected_type' is required to be a subtype of the original type + hint to restrict the type requirement of DataPipe instance. + """ + if isinstance(expected_type, tuple): + expected_type = tuple[expected_type] # type: ignore[valid-type] + _type_check(expected_type, msg="'expected_type' must be a type") + + if not issubtype(expected_type, self.type.param): + raise TypeError( + f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}" + ) + + self.type = _DataPipeType(expected_type) + return self diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f4b7dcb414c205614a694ccaa02961e45e9b3e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py @@ -0,0 +1,12 @@ +from torch.utils.data.datapipes.dataframe.dataframes import ( + CaptureDataFrame, + DFIterDataPipe, +) +from torch.utils.data.datapipes.dataframe.datapipes import DataFramesAsTuplesPipe + + +__all__ = ["CaptureDataFrame", "DFIterDataPipe", "DataFramesAsTuplesPipe"] + +# Please keep this list sorted +if __all__ != sorted(__all__): + raise AssertionError("__all__ is not sorted") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d4d457e77da0541cd85ba7fe935a83b8904f21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0170ee0ef4aaa834fce3bf2317ad7eba43781f48 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..691c1b764918b480d72d389af6d64109059c6599 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e704df0de74ad0b8dbac925ee17627592b521478 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37f968e15e56b1a652967ebe3686753d6e4d369b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfc5c268a17455cb30d981036996a716d0cc668 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +from typing import Any + + +_pandas: Any = None +_WITH_PANDAS: bool | None = None + + +def _try_import_pandas() -> bool: + try: + import pandas # type: ignore[import] + + global _pandas + _pandas = pandas + return True + except ImportError: + return False + + +# pandas used only for prototyping, will be shortly replaced with TorchArrow +def _with_pandas() -> bool: + global _WITH_PANDAS + if _WITH_PANDAS is None: + _WITH_PANDAS = _try_import_pandas() + return _WITH_PANDAS + + +class PandasWrapper: + @classmethod + def create_dataframe(cls, data, columns): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return _pandas.DataFrame(data, columns=columns) # type: ignore[union-attr] + + @classmethod + def is_dataframe(cls, data): + if not _with_pandas(): + return False + return isinstance(data, _pandas.core.frame.DataFrame) # type: ignore[union-attr] + + @classmethod + def is_column(cls, data): + if not _with_pandas(): + return False + return isinstance(data, _pandas.core.series.Series) # type: ignore[union-attr] + + @classmethod + def iterate(cls, data): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + yield from data.itertuples(index=False) + + @classmethod + def concat(cls, buffer): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return _pandas.concat(buffer) # type: ignore[union-attr] + + @classmethod + def get_item(cls, data, idx): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return data[idx : idx + 1] + + @classmethod + def get_len(cls, df): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return len(df.index) + + @classmethod + def get_columns(cls, df): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return list(df.columns.values.tolist()) + + +# When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class) +default_wrapper = PandasWrapper + + +def get_df_wrapper(): + return default_wrapper + + +def set_df_wrapper(wrapper) -> None: + global default_wrapper + default_wrapper = wrapper + + +def create_dataframe(data, columns=None): + wrapper = get_df_wrapper() + return wrapper.create_dataframe(data, columns) + + +def is_dataframe(data): + wrapper = get_df_wrapper() + return wrapper.is_dataframe(data) + + +def get_columns(data): + wrapper = get_df_wrapper() + return wrapper.get_columns(data) + + +def is_column(data): + wrapper = get_df_wrapper() + return wrapper.is_column(data) + + +def concat(buffer): + wrapper = get_df_wrapper() + return wrapper.concat(buffer) + + +def iterate(data): + wrapper = get_df_wrapper() + return wrapper.iterate(data) + + +def get_item(data, idx): + wrapper = get_df_wrapper() + return wrapper.get_item(data, idx) + + +def get_len(df): + wrapper = get_df_wrapper() + return wrapper.get_len(df) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py new file mode 100644 index 0000000000000000000000000000000000000000..5361c29b4822440d94b4949c5a6062ec7d58a2ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py @@ -0,0 +1,481 @@ +# mypy: allow-untyped-defs +from typing import Any, NoReturn + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe.structures import DataChunkDF +from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe + + +# TODO(VitalyFedyunin): Add error when two different traces get combined + +__all__ = [ + "Capture", + "CaptureA", + "CaptureAdd", + "CaptureCall", + "CaptureControl", + "CaptureDataFrame", + "CaptureDataFrameWithDataPipeOps", + "CaptureF", + "CaptureGetAttr", + "CaptureGetItem", + "CaptureInitial", + "CaptureLikeMock", + "CaptureMul", + "CaptureSetItem", + "CaptureSub", + "CaptureVariable", + "CaptureVariableAssign", + "DataFrameTracer", + "DataFrameTracedOps", + "disable_capture", + "get_val", +] + + +def disable_capture() -> None: + CaptureControl.disabled = True + + +class CaptureControl: + disabled = False + + +class DataFrameTracedOps(DFIterDataPipe): + def __init__(self, source_datapipe, output_var) -> None: + super().__init__() + self.source_datapipe = source_datapipe + self.output_var = output_var + + def __iter__(self): + for item in self.source_datapipe: + yield self.output_var.apply_ops(item) + + +# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions +DATAPIPES_OPS = [ + "_dataframes_as_tuples", + "groupby", + "_dataframes_filter", + "map", + "to_datapipe", + "shuffle", + "concat", + "batch", + "_dataframes_per_row", + "_dataframes_concat", + "_dataframes_shuffle", +] + +UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"] + + +class Capture: + # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures + + def __init__(self, schema_df=None) -> None: + self.ctx = {"operations": [], "variables": [], "schema_df": schema_df} + + def __str__(self) -> str: + return self._ops_str() + + def _ops_str(self): + res = "" + # pyrefly: ignore [not-iterable] + for op in self.ctx["operations"]: + if len(res) > 0: + res += "\n" + res += str(op) + return res + + def __getstate__(self): + # TODO(VitalyFedyunin): Currently can't pickle (why?) + self.ctx["schema_df"] = None + # pyrefly: ignore [not-iterable] + for var in self.ctx["variables"]: + var.calculated_value = None + state = {} + for item in self.__dict__: + state[item] = getattr(self, item) + return state + + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + + def __getattr__(self, attrname): + if attrname == "kwarg" or attrname == "kwargs": + raise RuntimeError("no kwargs!") + if attrname == "__deepcopy__": + raise AttributeError + result = CaptureGetAttr(self, attrname, ctx=self.ctx) + return result + + def __getitem__(self, key): + return CaptureGetItem(self, key, ctx=self.ctx) + + def __setitem__(self, key, value) -> None: + # pyrefly: ignore [missing-attribute] + self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx)) + + def __add__(self, add_val): + res = CaptureAdd(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + # pyrefly: ignore [missing-attribute] + self.ctx["operations"].append( + CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + ) + return var + + def __sub__(self, add_val): + res = CaptureSub(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + # pyrefly: ignore [missing-attribute] + self.ctx["operations"].append( + CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + ) + return var + + def __mul__(self, add_val): + res = CaptureMul(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + # pyrefly: ignore [missing-attribute] + self.ctx["operations"].append(t) + return var + + def _is_context_empty(self): + # pyrefly: ignore [bad-argument-type] + return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0 + + def apply_ops_2(self, dataframe) -> None: + # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + # pyrefly: ignore [unsupported-operation] + self.ctx["variables"][0].calculated_value = dataframe + # pyrefly: ignore [not-iterable] + for op in self.ctx["operations"]: + op.execute() + + @property + def columns(self): + self.apply_ops_2(self.ctx["schema_df"]) + value = self.execute() + return value.columns + + # TODO(VitalyFedyunin): Add tests + # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture + + def __call__(self, *args, **kwargs): + # TODO: Check if args or kwargs have more than one different context + if self._is_context_empty(): + # TODO: Allow CaptureA to take context from mock + for arg in args: + if isinstance(arg, Capture) and not arg._is_context_empty(): + self.ctx = arg.ctx + break + if self._is_context_empty(): + for k, v in kwargs.items(): + if isinstance(k, Capture) and not k._is_context_empty(): + self.ctx = k.ctx + break + if isinstance(v, Capture) and not v._is_context_empty(): + self.ctx = v.ctx + break + + res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs) + var = CaptureVariable(None, ctx=self.ctx) + t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res) + # pyrefly: ignore [missing-attribute] + self.ctx["operations"].append(t) + return var + + +class CaptureF(Capture): + def __init__(self, ctx=None, **kwargs) -> None: + super().__init__() + if ctx is None: + self.ctx = {"operations": [], "variables": []} + else: + self.ctx = ctx + self.kwargs = kwargs + + +class CaptureA(CaptureF): + def __str__(self) -> str: + return f"{self.kwargs['name']}" + + def execute(self): + value = self.kwargs["real_attribute"] + return value + + +class CaptureLikeMock: + def __init__(self, name) -> None: + import unittest.mock as mock + + # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead. + get_target, attribute = mock._get_target(name) # type: ignore[attr-defined] + self.get_target = get_target + self.attribute = attribute + self.name = name + + def __enter__(self): + self.save = getattr(self.get_target(), self.attribute) + capt = CaptureA(name=self.name, real_attribute=self.save) + setattr(self.get_target(), self.attribute, capt) + + def __exit__(self, *exc_info): + setattr(self.get_target(), self.attribute, self.save) + + +class CaptureCall(Capture): + def __init__(self, callable, ctx=None, **kwargs) -> None: + super().__init__() + if ctx is None: + self.ctx = {"operations": [], "variables": []} + else: + self.ctx = ctx + self.kwargs = kwargs + self.callable = callable + + def __str__(self) -> str: + return "{callable}({args},{kwargs})".format( + callable=self.callable, **self.kwargs + ) + + def execute(self): + # TODO: VitalyFedyunin execute kwargs and maybe nested structures + executed_args = [] + for arg in self.kwargs["args"]: + if isinstance(arg, Capture): + executed_args.append(arg.execute()) + else: + executed_args.append(arg) + left = get_val(self.callable) + return left(*executed_args, **self.kwargs["kwargs"]) + + +class CaptureVariableAssign(CaptureF): + def __str__(self) -> str: + variable = self.kwargs["variable"] + value = self.kwargs["value"] + return f"{variable} = {value}" + + def execute(self) -> None: + self.kwargs["variable"].calculated_value = self.kwargs["value"].execute() + + +class CaptureVariable(Capture): + # TODO(VitalyFedyunin): This should be atomic and thread safe + names_idx = 0 + + def __init__(self, value, ctx) -> None: + super().__init__() + if CaptureControl.disabled: + raise RuntimeError("Attempting to create capture variable with capture off") + self.ctx = ctx + self.value = value + self.name = f"var_{CaptureVariable.names_idx}" + CaptureVariable.names_idx += 1 + self.ctx["variables"].append(self) + + def __str__(self) -> str: + return self.name + + def execute(self): + return self.calculated_value + + def apply_ops(self, dataframe): + # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + # pyrefly: ignore [unsupported-operation] + self.ctx["variables"][0].calculated_value = dataframe + # pyrefly: ignore [not-iterable] + for op in self.ctx["operations"]: + op.execute() + return self.calculated_value + + +class CaptureGetItem(Capture): + def __init__(self, left, key, ctx) -> None: + super().__init__() + self.ctx = ctx + self.left = left + self.key = key + + def __str__(self) -> str: + return f"{self.left}[{get_val(self.key)}]" + + def execute(self): + left = self.left.execute() + return left[self.key] + + +class CaptureSetItem(Capture): + def __init__(self, left, key, value, ctx) -> None: + super().__init__() + self.ctx = ctx + self.left = left + self.key = key + self.value = value + + def __str__(self) -> str: + return f"{self.left}[{get_val(self.key)}] = {self.value}" + + def execute(self) -> None: + left = self.left.execute() + value = self.value.execute() + left[self.key] = value + + +class CaptureAdd(Capture): + def __init__(self, left, right, ctx) -> None: + super().__init__() + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self) -> str: + return f"{self.left} + {self.right}" + + def execute(self): + return get_val(self.left) + get_val(self.right) + + +class CaptureMul(Capture): + def __init__(self, left, right, ctx) -> None: + super().__init__() + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self) -> str: + return f"{self.left} * {self.right}" + + def execute(self): + return get_val(self.left) * get_val(self.right) + + +class CaptureSub(Capture): + def __init__(self, left, right, ctx) -> None: + super().__init__() + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self) -> str: + return f"{self.left} - {self.right}" + + def execute(self): + return get_val(self.left) - get_val(self.right) + + +class CaptureGetAttr(Capture): + def __init__(self, src, name, ctx) -> None: + super().__init__() + self.ctx = ctx + self.src = src + self.name = name + + def __str__(self) -> str: + return f"{self.src}.{self.name}" + + def execute(self): + val = get_val(self.src) + return getattr(val, self.name) + + +def get_val(capture): + if isinstance(capture, Capture): + return capture.execute() + elif isinstance(capture, str): + return f'"{capture}"' + else: + return capture + + +class CaptureInitial(CaptureVariable): + def __init__(self, schema_df=None) -> None: + # pyrefly: ignore [bad-assignment] + new_ctx: dict[str, list[Any]] = { + "operations": [], + "variables": [], + "schema_df": schema_df, + } + super().__init__(None, new_ctx) + self.name = f"input_{self.name}" + + +class CaptureDataFrame(CaptureInitial): + pass + + +class CaptureDataFrameWithDataPipeOps(CaptureDataFrame): + def as_datapipe(self): + # pyrefly: ignore [unsupported-operation] + return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self) + + def raw_iterator(self): + return self.as_datapipe().__iter__() + + def __iter__(self): + return iter(self._dataframes_as_tuples()) + + def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF): + dp = self._dataframes_per_row()._dataframes_concat(batch_size) + dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class) + dp._dp_contains_dataframe = True + return dp + + def groupby( + self, + group_key_fn, + *, + buffer_size=10000, + group_size=None, + guaranteed_group_size=None, + drop_remaining=False, + ): + dp = self._dataframes_per_row() + dp = dp.as_datapipe().groupby( + group_key_fn, + buffer_size=buffer_size, + group_size=group_size, + guaranteed_group_size=guaranteed_group_size, + drop_remaining=drop_remaining, + ) + return dp + + def shuffle(self, *args, **kwargs): + return self._dataframes_shuffle(*args, **kwargs) + + def filter(self, *args, **kwargs): + return self._dataframes_filter(*args, **kwargs) + + def collate(self, *args, **kwargs) -> NoReturn: + raise RuntimeError("Can't collate unbatched DataFrames stream") + + def __getattr__(self, attrname): # ? + if attrname in UNIMPLEMENTED_ATTR: + raise AttributeError("Attempting to get ", attrname) + if attrname in DATAPIPES_OPS: + return (self.as_datapipe()).__getattr__(attrname) + return super().__getattr__(attrname) + + +@functional_datapipe("trace_as_dataframe") +class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc] + source_datapipe: Any | None = None + + # TODO(VitalyFedyunin): Must implement all special functions of datapipes + + def set_shuffle_settings(self, *args, **kwargs) -> None: + pass + + def is_shardable(self) -> bool: + return False + + def __init__(self, source_datapipe, schema_df=None) -> None: + self.source_datapipe = source_datapipe + if schema_df is None: + schema_df = next(iter(self.source_datapipe)) + super().__init__(schema_df=schema_df) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py new file mode 100644 index 0000000000000000000000000000000000000000..50c5a44dfd5f323ca6c27276907abd48d6d5532c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +import random +from typing import Any + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe + + +__all__ = [ + "ConcatDataFramesPipe", + "DataFramesAsTuplesPipe", + "ExampleAggregateAsDataFrames", + "FilterDataFramesPipe", + "PerRowDataFramesPipe", + "ShuffleDataFramesPipe", +] + + +@functional_datapipe("_dataframes_as_tuples") +class DataFramesAsTuplesPipe(IterDataPipe): + def __init__(self, source_datapipe) -> None: + self.source_datapipe = source_datapipe + + def __iter__(self): + for df in self.source_datapipe: + # for record in df.to_records(index=False): + yield from df_wrapper.iterate(df) + + +@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True) +class PerRowDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe) -> None: + self.source_datapipe = source_datapipe + + def __iter__(self): + for df in self.source_datapipe: + # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup + for i in range(len(df)): + yield df[i : i + 1] + + +@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True) +class ConcatDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe, batch=3) -> None: + self.source_datapipe = source_datapipe + self.n_batch = batch + + def __iter__(self): + buffer = [] + for df in self.source_datapipe: + buffer.append(df) + if len(buffer) == self.n_batch: + yield df_wrapper.concat(buffer) + buffer = [] + if buffer: + yield df_wrapper.concat(buffer) + + +@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True) +class ShuffleDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe) -> None: + self.source_datapipe = source_datapipe + + def __iter__(self): + size = None + all_buffer: list[Any] = [] + for df in self.source_datapipe: + if size is None: + size = df_wrapper.get_len(df) + all_buffer.extend( + df_wrapper.get_item(df, i) for i in range(df_wrapper.get_len(df)) + ) + random.shuffle(all_buffer) + buffer = [] + for df in all_buffer: + buffer.append(df) + if len(buffer) == size: + yield df_wrapper.concat(buffer) + buffer = [] + if buffer: + yield df_wrapper.concat(buffer) + + +@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True) +class FilterDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe, filter_fn) -> None: + self.source_datapipe = source_datapipe + self.filter_fn = filter_fn + + def __iter__(self): + size = None + all_buffer = [] + filter_res = [] + # pyrefly: ignore [bad-assignment] + for df in self.source_datapipe: + if size is None: + size = len(df.index) + for i in range(len(df.index)): + all_buffer.append(df[i : i + 1]) + filter_res.append(self.filter_fn(df.iloc[i])) + + buffer = [] + for df, res in zip(all_buffer, filter_res, strict=True): + if res: + buffer.append(df) + if len(buffer) == size: + yield df_wrapper.concat(buffer) + buffer = [] + if buffer: + yield df_wrapper.concat(buffer) + + +@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True) +class ExampleAggregateAsDataFrames(DFIterDataPipe): + def __init__(self, source_datapipe, dataframe_size=10, columns=None) -> None: + self.source_datapipe = source_datapipe + self.columns = columns + self.dataframe_size = dataframe_size + + def _as_list(self, item): + try: + return list(item) + except ( + Exception + ): # TODO(VitalyFedyunin): Replace with better iterable exception + return [item] + + def __iter__(self): + aggregate = [] + for item in self.source_datapipe: + aggregate.append(self._as_list(item)) + if len(aggregate) == self.dataframe_size: + yield df_wrapper.create_dataframe(aggregate, columns=self.columns) + aggregate = [] + if len(aggregate) > 0: + yield df_wrapper.create_dataframe(aggregate, columns=self.columns) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py new file mode 100644 index 0000000000000000000000000000000000000000..26b4c33db03cc584f223444c07730ef67f4495e7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py @@ -0,0 +1,22 @@ +from collections.abc import Iterator +from typing import Any + +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import DataChunk + + +__all__ = ["DataChunkDF"] + + +class DataChunkDF(DataChunk): + """DataChunkDF iterating over individual items inside of DataFrame containers, to access DataFrames user `raw_iterator`.""" + + def __iter__(self) -> Iterator[Any]: + for df in self.items: + yield from df_wrapper.iterate(df) + + def __len__(self) -> int: + total_len = 0 + for df in self.items: + total_len += df_wrapper.get_len(df) + return total_len diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py new file mode 100644 index 0000000000000000000000000000000000000000..51c1689008530b1ec4e78c9c921fd9aa6629ecfb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py @@ -0,0 +1,427 @@ +import functools +import pickle +from collections.abc import Callable, Iterable, Iterator +from typing import TypeVar + +from torch.utils._import_utils import import_dill +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta +from torch.utils.data.datapipes.utils.common import ( + _deprecation_warning, + _iter_deprecated_functional_names, + _map_deprecated_functional_names, +) +from torch.utils.data.dataset import Dataset, IterableDataset + + +dill = import_dill() +HAS_DILL = dill is not None + +__all__ = [ + "DataChunk", + "DFIterDataPipe", + "IterDataPipe", + "MapDataPipe", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +UNTRACABLE_DATAFRAME_PIPES = [ + "batch", # As it returns DataChunks + "groupby", # As it returns DataChunks + "_dataframes_as_tuples", # As it unpacks DF + "trace_as_dataframe", # As it used to mark DF for tracing +] + + +class DataChunk(list[_T]): + def __init__(self, items: Iterable[_T]) -> None: + items = list(items) + super().__init__(items) + self.items = items + + def as_str(self, indent: str = "") -> str: + return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]" + + def __iter__(self) -> Iterator[_T]: + yield from super().__iter__() + + def raw_iterator(self) -> Iterator[_T]: + yield from self.items + + +class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): + r""" + Iterable-style DataPipe. + + All DataPipes that represent an iterable of data samples should subclass this. + This style of DataPipes is particularly useful when data come from a stream, or + when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its + elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its + method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should + override ``reset()`` if necessary. The common usages include resetting buffers, pointers, + and various state variables within the custom ``IterDataPipe``. + + Note: + Only `one` iterator can be valid for each ``IterDataPipe`` at a time, + and the creation a second iterator will invalidate the first one. This constraint is necessary because + some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators. + The code example below presents details on how this constraint looks in practice. + If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_. + + These DataPipes can be invoked in two ways, using the class constructor or applying their + functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes). + You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple + operations in succession. + + .. _GitHub IterDataPipe Single Iterator Issue: + https://github.com/pytorch/data/issues/45 + + Note: + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the DataPipe object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Examples: + General Usage: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> dp = IterableWrapper(range(10)) + >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor + >>> map_dp_2 = dp.map( + ... lambda x: x + 1 + ... ) # Using functional form (recommended) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0) + >>> list(filter_dp) + [2, 4, 6, 8, 10] + Single Iterator Constraint Example: + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> source_dp = IterableWrapper(range(10)) + >>> it1 = iter(source_dp) + >>> list(it1) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + >>> it1 = iter(source_dp) + >>> it2 = iter( + ... source_dp + ... ) # The creation of a new iterator invalidates `it1` + >>> next(it2) + 0 + >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` + """ + + functions: dict[str, Callable] = {} + reduce_ex_hook: Callable | None = None + getstate_hook: Callable | None = None + str_hook: Callable | None = None + repr_hook: Callable | None = None + _valid_iterator_id: int | None = None + _number_of_samples_yielded: int = 0 + _snapshot_state: _SnapshotState = _SnapshotState.NotStarted + _fast_forward_iterator: Iterator | None = None + + def __iter__(self) -> Iterator[_T_co]: + # pyrefly: ignore [bad-return] + return self + + def __getattr__(self, attribute_name): + if attribute_name in IterDataPipe.functions: + if attribute_name in _iter_deprecated_functional_names: + kwargs = _iter_deprecated_functional_names[attribute_name] + _deprecation_warning(**kwargs) + f = IterDataPipe.functions[attribute_name] + function = functools.partial(f, self) + functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",)) + return function + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attribute_name}" + ) + + @classmethod + def register_function(cls, function_name, function) -> None: + cls.functions[function_name] = function + + @classmethod + def register_datapipe_as_function( + cls, function_name, cls_to_register, enable_df_api_tracing=False + ) -> None: + if function_name in cls.functions: + raise Exception( # noqa: TRY002 + f"Unable to add DataPipe function name {function_name} as it is already taken" + ) + + def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs): + result_pipe = cls(source_dp, *args, **kwargs) + if isinstance(result_pipe, IterDataPipe): + if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe): + if function_name not in UNTRACABLE_DATAFRAME_PIPES: + result_pipe = result_pipe.trace_as_dataframe() + + return result_pipe + + function = functools.partial( + class_function, cls_to_register, enable_df_api_tracing + ) + functools.update_wrapper( + wrapper=function, wrapped=cls_to_register, assigned=("__doc__",) + ) + cls.functions[function_name] = function + + def __getstate__(self): + """ + Serialize `lambda` functions when `dill` is available. + + If this doesn't cover your custom DataPipe's use case, consider writing custom methods for + `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization. + """ + state = self.__dict__ + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __reduce_ex__(self, *args, **kwargs): + if IterDataPipe.reduce_ex_hook is not None: + try: + return IterDataPipe.reduce_ex_hook(self) + except NotImplementedError: + pass + return super().__reduce_ex__(*args, **kwargs) + + @classmethod + def set_getstate_hook(cls, hook_fn) -> None: + if IterDataPipe.getstate_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing getstate_hook") + IterDataPipe.getstate_hook = hook_fn + + @classmethod + def set_reduce_ex_hook(cls, hook_fn) -> None: + if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing reduce_ex_hook") + IterDataPipe.reduce_ex_hook = hook_fn + + def __repr__(self) -> str: + if self.repr_hook is not None: + return self.repr_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __str__(self) -> str: + if self.str_hook is not None: + return self.str_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __dir__(self): + # for auto-completion in a REPL (e.g. Jupyter notebook) + return list(super().__dir__()) + list(self.functions.keys()) + + def reset(self) -> None: + r""" + Reset the `IterDataPipe` to the initial state. + + By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities, + they may want to override this method with implementations that + may clear the buffers and reset pointers of the DataPipe. + The `reset` method is always called when `__iter__` is called as part of `hook_iterator`. + """ + + +class DFIterDataPipe(IterDataPipe): + def _is_dfpipe(self) -> bool: + return True + + +class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): + r""" + Map-style DataPipe. + + All datasets that represent a map from keys to data samples should subclass this. + Subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given, unique key. Subclasses can also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. + + These DataPipes can be invoked in two ways, using the class constructor or applying their + functional form onto an existing `MapDataPipe` (recommend, available to most but not all DataPipes). + + Note: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + DataPipe with non-integral indices/keys, a custom sampler must be provided. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> batch_dp = map_dp_1.batch(batch_size=2) + >>> list(batch_dp) + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + """ + + functions: dict[str, Callable] = {} + reduce_ex_hook: Callable | None = None + getstate_hook: Callable | None = None + str_hook: Callable | None = None + repr_hook: Callable | None = None + + def __getattr__(self, attribute_name): + if attribute_name in MapDataPipe.functions: + if attribute_name in _map_deprecated_functional_names: + kwargs = _map_deprecated_functional_names[attribute_name] + _deprecation_warning(**kwargs) + f = MapDataPipe.functions[attribute_name] + function = functools.partial(f, self) + functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",)) + return function + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attribute_name}" + ) + + @classmethod + def register_function(cls, function_name, function) -> None: + cls.functions[function_name] = function + + @classmethod + def register_datapipe_as_function(cls, function_name, cls_to_register) -> None: + if function_name in cls.functions: + raise Exception( # noqa: TRY002 + f"Unable to add DataPipe function name {function_name} as it is already taken" + ) + + def class_function(cls, source_dp, *args, **kwargs): + result_pipe = cls(source_dp, *args, **kwargs) + return result_pipe + + function = functools.partial(class_function, cls_to_register) + functools.update_wrapper( + wrapper=function, wrapped=cls_to_register, assigned=("__doc__",) + ) + cls.functions[function_name] = function + + def __getstate__(self): + """ + Serialize `lambda` functions when `dill` is available. + + If this doesn't cover your custom DataPipe's use case, consider writing custom methods for + `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization. + """ + state = self.__dict__ + if MapDataPipe.getstate_hook is not None: + return MapDataPipe.getstate_hook(state) + return state + + def __reduce_ex__(self, *args, **kwargs): + if MapDataPipe.reduce_ex_hook is not None: + try: + return MapDataPipe.reduce_ex_hook(self) + except NotImplementedError: + pass + return super().__reduce_ex__(*args, **kwargs) + + @classmethod + def set_getstate_hook(cls, hook_fn) -> None: + if MapDataPipe.getstate_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing getstate_hook") + MapDataPipe.getstate_hook = hook_fn + + @classmethod + def set_reduce_ex_hook(cls, hook_fn) -> None: + if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing reduce_ex_hook") + MapDataPipe.reduce_ex_hook = hook_fn + + def __repr__(self) -> str: + if self.repr_hook is not None: + return self.repr_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __str__(self) -> str: + if self.str_hook is not None: + return self.str_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __dir__(self): + # for auto-completion in a REPL (e.g. Jupyter notebook) + return list(super().__dir__()) + list(self.functions.keys()) + + +class _DataPipeSerializationWrapper: + def __init__(self, datapipe) -> None: + self._datapipe = datapipe + + def __getstate__(self): + use_dill = False + try: + value = pickle.dumps(self._datapipe) + except Exception: + if HAS_DILL: + # pyrefly: ignore [missing-attribute] + value = dill.dumps(self._datapipe) + use_dill = True + else: + raise + return (value, use_dill) + + def __setstate__(self, state): + value, use_dill = state + if use_dill: + # pyrefly: ignore [missing-attribute] + self._datapipe = dill.loads(value) + else: + self._datapipe = pickle.loads(value) + + def __len__(self) -> int: + try: + return len(self._datapipe) + except Exception as e: + raise TypeError( + f"{type(self).__name__} instance doesn't have valid length" + ) from e + + +class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): + def __init__(self, datapipe: IterDataPipe[_T_co]) -> None: + super().__init__(datapipe) + # pyrefly: ignore [invalid-type-var] + self._datapipe_iter: Iterator[_T_co] | None = None + + def __iter__(self) -> "_IterDataPipeSerializationWrapper": + self._datapipe_iter = iter(self._datapipe) + return self + + def __next__(self) -> _T_co: # type: ignore[type-var] + if self._datapipe_iter is None: + raise AssertionError( + "Iterator has not been initialized; call __iter__() before __next__()" + ) + return next(self._datapipe_iter) + + +class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe): + def __getitem__(self, idx): + return self._datapipe[idx] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7f49cc212383b2a635c36e1dc96c040d1d63868d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi @@ -0,0 +1,746 @@ +# @generated by torch/utils/data/datapipes/gen_pyi.py from datapipe.pyi.in +# mypy: allow-untyped-defs +# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection +# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt +# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other +# classes/objects here, even though we are not injecting extra code into them at the moment. + +from collections.abc import Callable, Iterable, Iterator +from typing import Any, Literal, TypeVar + +from torch.utils.data import Dataset, default_collate, IterableDataset +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +UNTRACABLE_DATAFRAME_PIPES: Any + +class DataChunk(list[_T]): + items: list[_T] + def __init__(self, items: Iterable[_T]) -> None: ... + def as_str(self, indent: str = "") -> str: ... + def __iter__(self) -> Iterator[_T]: ... + def raw_iterator(self) -> Iterator[_T]: ... + +class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): + functions: dict[str, Callable] = ... + reduce_ex_hook: Callable | None = ... + getstate_hook: Callable | None = ... + str_hook: Callable | None = ... + repr_hook: Callable | None = ... + def __getattr__(self, attribute_name: Any): ... + @classmethod + def register_function(cls, function_name: Any, function: Any) -> None: ... + @classmethod + def register_datapipe_as_function( + cls, + function_name: Any, + cls_to_register: Any, + ): ... + def __getstate__(self): ... + def __reduce_ex__(self, *args: Any, **kwargs: Any): ... + @classmethod + def set_getstate_hook(cls, hook_fn: Any) -> None: ... + @classmethod + def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ... + # Functional form of 'BatcherMapDataPipe' + def batch( + self, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> MapDataPipe: + r""" + Create mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, + or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> batch_dp = dp.batch(batch_size=2) + >>> list(batch_dp) + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] + """ + # Functional form of 'ConcaterMapDataPipe' + def concat(self, *datapipes: MapDataPipe) -> MapDataPipe: + r""" + Concatenate multiple Map DataPipes (functional name: ``concat``). + + The new index of is the cumulative sum of source DataPipes. + For example, if there are 2 source DataPipes both with length 5, + index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to + elements of the first DataPipe, and 5 to 9 would refer to elements + of the second DataPipe. + + Args: + datapipes: Map DataPipes being concatenated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(3)) + >>> concat_dp = dp1.concat(dp2) + >>> list(concat_dp) + [0, 1, 2, 0, 1, 2] + """ + # Functional form of 'MapperMapDataPipe' + def map(self, fn: Callable = ...) -> MapDataPipe: + r""" + Apply the input function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source MapDataPipe + fn: Function being applied to each item + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + # Functional form of 'ShufflerIterDataPipe' + def shuffle(self, *, indices: list | None = None) -> IterDataPipe: + r""" + Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``). + + When it is used with :class:`~torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed + for each worker process. + + Args: + datapipe: MapDataPipe being shuffled + indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> shuffle_dp = dp.shuffle().set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + >>> list(shuffle_dp) + [6, 1, 9, 5, 2, 4, 7, 3, 8, 0] + >>> # Reset seed for Shuffler + >>> shuffle_dp = shuffle_dp.set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + + Note: + Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an + ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to + the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order + of data during data-processing. + """ + # Functional form of 'ZipperMapDataPipe' + def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe: + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Map DataPipes being aggregated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(10, 13)) + >>> zip_dp = dp1.zip(dp2) + >>> list(zip_dp) + [(0, 10), (1, 11), (2, 12)] + """ + +class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): + functions: dict[str, Callable] = ... + reduce_ex_hook: Callable | None = ... + getstate_hook: Callable | None = ... + str_hook: Callable | None = ... + repr_hook: Callable | None = ... + _number_of_samples_yielded: int = ... + _snapshot_state: _SnapshotState = _SnapshotState.Iterating # noqa: PYI015 + _fast_forward_iterator: Iterator | None = ... + def __getattr__(self, attribute_name: Any): ... + @classmethod + def register_function(cls, function_name: Any, function: Any) -> None: ... + @classmethod + def register_datapipe_as_function( + cls, + function_name: Any, + cls_to_register: Any, + enable_df_api_tracing: bool = ..., + ): ... + def __getstate__(self): ... + def __reduce_ex__(self, *args: Any, **kwargs: Any): ... + @classmethod + def set_getstate_hook(cls, hook_fn: Any) -> None: ... + @classmethod + def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ... + # Functional form of 'BatcherIterDataPipe' + def batch( + self, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> IterDataPipe: + r""" + Creates mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the + last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding, + defaults to ``DataChunk`` + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> dp = dp.batch(batch_size=3, drop_last=True) + >>> list(dp) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + # Functional form of 'CollatorIterDataPipe' + def collate( + self, + conversion: Callable[..., Any]| dict[str | Any, Callable | Any]| None = default_collate, + collate_fn: Callable | None = None, + ) -> IterDataPipe: # fmt: skip + r""" + Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``). + + By default, it uses :func:`torch.utils.data.default_collate`. + + .. note:: + While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the + default behavior and `functools.partial` to specify any additional arguments. + + Args: + datapipe: Iterable DataPipe being collated + collate_fn: Customized collate function to collect and combine data or a batch of data. + Default function collates to Tensor(s) based on data type. + + Example: + >>> # xdoctest: +SKIP + >>> # Convert integer data to float Tensor + >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): + ... def __init__(self, start, end): + ... super(MyIterDataPipe).__init__() + ... assert end > start, "this example only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + ... def __len__(self): + ... return self.end - self.start + >>> ds = MyIterDataPipe(start=3, end=7) + >>> print(list(ds)) + [3, 4, 5, 6] + >>> def collate_fn(batch): + ... return torch.tensor(batch, dtype=torch.float) + >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) + >>> print(list(collated_ds)) + [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] + """ + # Functional form of 'ConcaterIterDataPipe' + def concat(self, *datapipes: IterDataPipe) -> IterDataPipe: + r""" + Concatenates multiple Iterable DataPipes (functional name: ``concat``). + + The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones. + + Args: + datapipes: Iterable DataPipes being concatenated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> import random + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1 = IterableWrapper(range(3)) + >>> dp2 = IterableWrapper(range(5)) + >>> list(dp1.concat(dp2)) + [0, 1, 2, 0, 1, 2, 3, 4] + """ + # Functional form of 'DemultiplexerIterDataPipe' + def demux( + self, + num_instances: int, + classifier_fn: Callable[[_T_co], int | None], + drop_none: bool = False, + buffer_size: int = 1000, + ) -> list[IterDataPipe]: + r""" + Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``). + + A list of the child DataPipes is returned from this operation. + + Args: + datapipe: Iterable DataPipe being filtered + num_instances: number of instances of the DataPipe to create + classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None`` + drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None`` + buffer_size: this defines the maximum number of inputs that the buffer can hold across all child + DataPipes while waiting for their values to be yielded. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + + Examples: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def odd_or_even(n): + ... return n % 2 + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even) + >>> list(dp1) + [0, 2, 4] + >>> list(dp2) + [1, 3] + >>> # It can also filter out any element that gets `None` from the `classifier_fn` + >>> def odd_or_even_no_zero(n): + ... return n % 2 if n != 0 else None + >>> dp1, dp2 = source_dp.demux( + ... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True + ... ) + >>> list(dp1) + [2, 4] + >>> list(dp2) + [1, 3] + """ + # Functional form of 'FilterIterDataPipe' + def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe: + r""" + Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``). + + Args: + datapipe: Iterable DataPipe being filtered + filter_fn: Customized function mapping an element to a boolean. + input_col: Index or indices of data which ``filter_fn`` is applied, such as: + + - ``None`` as default to apply ``filter_fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def is_even(n): + ... return n % 2 == 0 + >>> dp = IterableWrapper(range(5)) + >>> filter_dp = dp.filter(filter_fn=is_even) + >>> list(filter_dp) + [0, 2, 4] + """ + # Functional form of 'ForkerIterDataPipe' + def fork( + self, + num_instances: int, + buffer_size: int = 1000, + copy: Literal["shallow", "deep"] | None = None, + ) -> list[IterDataPipe]: + r""" + Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``). + + Args: + datapipe: Iterable DataPipe being copied + num_instances: number of instances of the datapipe to create + buffer_size: this restricts how far ahead the leading child DataPipe + can read relative to the slowest child DataPipe. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + copy: copy strategy to use for items yielded by each branch. Supported + options are ``None`` for no copying, ``"shallow"`` for shallow object + copies, and ``"deep"`` for deep object copies. Defaults to ``None``. + + Note: + All branches of the forked pipeline return the identical object unless + the copy parameter is supplied. If the object is mutable or contains + mutable objects, changing them in one branch will affect all others. + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.fork(num_instances=2) + >>> list(dp1) + [0, 1, 2, 3, 4] + >>> list(dp2) + [0, 1, 2, 3, 4] + """ + # Functional form of 'GrouperIterDataPipe' + def groupby( + self, + group_key_fn: Callable[[_T_co], Any], + *, + keep_key: bool = False, + buffer_size: int = 10000, + group_size: int | None = None, + guaranteed_group_size: int | None = None, + drop_remaining: bool = False, + ) -> IterDataPipe: + r""" + Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``. + + (functional name: ``groupby``). + + The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group + will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full, + the DataPipe will yield the largest batch with the same key, provided that its size is larger + than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``. + + After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity + will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``. + + Args: + datapipe: Iterable datapipe to be grouped + group_key_fn: Function used to generate group key from the data of the source datapipe + keep_key: Option to yield the matching key along with the items in a tuple, + resulting in `(key, [items])` otherwise returning [items] + buffer_size: The size of buffer for ungrouped data + group_size: The max size of each group, a batch is yielded as soon as it reaches this size + guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full + drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer + when the buffer is full + + Example: + >>> import os + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def group_fn(file): + ... return os.path.basename(file).split(".")[0] + >>> source_dp = IterableWrapper( + ... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"] + ... ) + >>> dp0 = source_dp.groupby(group_key_fn=group_fn) + >>> list(dp0) + [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] + >>> # A group is yielded as soon as its size equals to `group_size` + >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2) + >>> list(dp1) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` + >>> dp2 = source_dp.groupby( + ... group_key_fn=group_fn, + ... buffer_size=3, + ... group_size=3, + ... guaranteed_group_size=2, + ... ) + >>> list(dp2) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + """ + # Functional form of 'FileListerIterDataPipe' + def list_files( + self, + masks: str | list[str] = "", + *, + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, + length: int = -1, + ) -> IterDataPipe: + r""" + Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory. + + Multiple root directories can be provided (functional name: ``list_files``). + + Args: + root: Root directory or a sequence of root directories + masks: Unix style filter string or string list for filtering file name(s) + recursive: Whether to return pathname from nested directories or not + abspath: Whether to return relative pathname or absolute pathname + non_deterministic: Whether to return pathname in sorted order or not. + If ``False``, the results yielded from each root directory will be sorted + length: Nominal length of the datapipe + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister + >>> dp = FileLister(root=".", recursive=True) + >>> list(dp) + ['example.py', './data/data.tar'] + """ + # Functional form of 'MapperIterDataPipe' + def map( + self, + fn: Callable, + input_col=None, + output_col=None, + ) -> IterDataPipe: + r""" + Applies a function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source Iterable DataPipe + fn: Function being applied over each item + input_col: Index or indices of data which ``fn`` is applied, such as: + + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified + only when ``input_col`` is not ``None`` + + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with + multiple indices, the left-most one is used, and other indices will be removed. + - Integer is used for list/tuple. ``-1`` represents to append result at the end. + - Key is used for dict. New key is acceptable. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = IterableWrapper(range(10)) + >>> # Invocation via functional form is preferred + ... map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` + >>> # Use `functools.partial` or explicitly define the function instead + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + # Functional form of 'MultiplexerIterDataPipe' + def mux(self, *datapipes) -> IterDataPipe: + r""" + Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``). + + As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, + and so on. It ends when the shortest input DataPipe is exhausted. + + Args: + datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(3)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) + >>> list(dp1.mux(dp2, dp3)) + [0, 10, 20, 1, 11, 21, 2, 12, 22] + """ + # Functional form of 'FileOpenerIterDataPipe' + def open_files( + self, + mode: str = "r", + encoding: str | None = None, + length: int = -1, + ) -> IterDataPipe: + r""" + Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``). + + Args: + datapipe: Iterable datapipe that provides pathnames + mode: An optional string that specifies the mode in which + the file is opened by ``open()``. It defaults to ``r``, other options are + ``b`` for reading in binary mode and ``t`` for text mode. + encoding: An optional string that specifies the encoding of the + underlying file. It defaults to ``None`` to match the default encoding of ``open``. + length: Nominal length of the datapipe + + Note: + The opened file handles will be closed by Python's GC periodically. Users can choose + to close them explicitly. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import ( + ... FileLister, + ... FileOpener, + ... StreamReader, + ... ) + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt")) + >>> dp = FileOpener(dp) + >>> dp = StreamReader(dp) + >>> list(dp) + [('./abc.txt', 'abc')] + """ + # Functional form of 'StreamReaderIterDataPipe' + def read_from_stream(self, chunk: int | None = None) -> IterDataPipe: + r""" + Given IO streams and their label names, yield bytes with label name as tuple. + + (functional name: ``read_from_stream``). + + Args: + datapipe: Iterable DataPipe provides label/URL and byte stream + chunk: Number of bytes to be read from stream per iteration. + If ``None``, all bytes will be read until the EOF. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader + >>> from io import StringIO + >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))]) + >>> list(StreamReader(dp, chunk=1)) + [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')] + """ + # Functional form of 'RoutedDecoderIterDataPipe' + def routed_decode( + self, + *handlers: Callable, + key_fn: Callable = ..., + ) -> IterDataPipe: + r""" + Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple. + + (functional name: ``routed_decode``) + + Args: + datapipe: Iterable datapipe that provides pathname and binary stream in tuples + handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder + handlers will be set as default. If multiple handles are provided, the priority + order follows the order of handlers (the first handler has the top priority) + key_fn: Function for decoder to extract key from pathname to dispatch handlers. + Default is set to extract file extension from pathname + + Note: + When ``key_fn`` is specified returning anything other than extension, the default + handler will not work and users need to specify custom handler. Custom handler + could use regex to determine the eligibility to handle data. + """ + # Functional form of 'ShardingFilterIterDataPipe' + def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe: + r""" + Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``). + + After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the + original DataPipe, where `n` equals to the number of instances. + + Args: + source_datapipe: Iterable DataPipe that will be sharded + """ + # Functional form of 'ShufflerIterDataPipe' + def shuffle( + self, + *, + buffer_size: int = 10000, + unbatch_level: int = 0, + ) -> IterDataPipe: + r""" + Shuffle the input DataPipe with a buffer (functional name: ``shuffle``). + + The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then, + each item will be yielded from the buffer by reservoir sampling via iterator. + + ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the + datapipe is not shuffled. In order to fully shuffle all elements from datapipe, + ``buffer_size`` is required to be greater than or equal to the size of datapipe. + + When it is used with :class:`torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed + for each worker process. + + Args: + datapipe: The IterDataPipe being shuffled + buffer_size: The buffer size for shuffling (default to ``10000``) + unbatch_level: Specifies if it is necessary to unbatch source data before + applying the shuffle + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> shuffle_dp = dp.shuffle() + >>> list(shuffle_dp) + [0, 4, 1, 6, 3, 2, 9, 5, 7, 8] + """ + # Functional form of 'UnBatcherIterDataPipe' + def unbatch(self, unbatch_level: int = 1) -> IterDataPipe: + r""" + Undos batching of data (functional name: ``unbatch``). + + In other words, it flattens the data up to the specified level within a batched DataPipe. + + Args: + datapipe: Iterable DataPipe being un-batched + unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``, + it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]]) + >>> dp1 = source_dp.unbatch() + >>> list(dp1) + [[0, 1], [2], [3, 4], [5], [6]] + >>> dp2 = source_dp.unbatch(unbatch_level=2) + >>> list(dp2) + [0, 1, 2, 3, 4, 5, 6] + """ + # Functional form of 'ZipperIterDataPipe' + def zip(self, *datapipes: IterDataPipe) -> IterDataPipe: + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + The output is stopped as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Iterable DataPipes being aggregated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(5)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) + >>> list(dp1.zip(dp2, dp3)) + [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] + """ + +class DFIterDataPipe(IterDataPipe): + def _is_dfpipe(self): ... + def __iter__(self): ... + +class _DataPipeSerializationWrapper: + def __init__(self, datapipe): ... + def __getstate__(self): ... + def __setstate__(self, state): ... + def __len__(self): ... + +class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): + def __iter__(self): ... + +class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe): + def __getitem__(self, idx): ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py new file mode 100644 index 0000000000000000000000000000000000000000..90f9d80a2e7fef61459d525d32486211415ad3ed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py @@ -0,0 +1,336 @@ +# mypy: allow-untyped-defs +import os +from collections import defaultdict +from pathlib import Path +from typing import Any +from typing_extensions import deprecated + + +try: + from torchgen.api.python import format_function_signature + from torchgen.utils import FileManager +except ImportError: + import sys + + REPO_ROOT = Path(__file__).absolute().parents[4] + sys.path.insert(0, str(REPO_ROOT)) + + from torchgen.api.python import format_function_signature + from torchgen.utils import FileManager + + if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT): + del sys.path[0] + + +__all__: list[str] = [] # not intended to expose any symbols + + +def __dir__() -> list[str]: + return [] # appease public API test + + +@deprecated( + "`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.", + category=FutureWarning, +) +def materialize_lines(lines: list[str], indentation: int) -> str: + output = "" + new_line_with_indent = "\n" + " " * indentation + for i, line in enumerate(lines): + if i != 0: + output += new_line_with_indent + output += line.replace("\n", new_line_with_indent) + return output + + +@deprecated( + "`torch.utils.data.datapipes.gen_pyi.gen_from_template` is deprecated and will be removed in the future.", + category=FutureWarning, +) +def gen_from_template( + dir: str, + template_name: str, + output_name: str, + replacements: list[tuple[str, Any, int]], +) -> None: + template_path = os.path.join(dir, template_name) + output_path = os.path.join(dir, output_name) + + with open(template_path, encoding="utf-8") as f: + content = f.read() + for placeholder, lines, indentation in replacements: + with open(output_path, "w", encoding="utf-8") as f: + content = content.replace( + placeholder, materialize_lines(lines, indentation) + ) + f.write(content) + + +def find_file_paths(dir_paths: list[str], files_to_exclude: set[str]) -> set[str]: + """ + When given a path to a directory, returns the paths to the relevant files within it. + + This function does NOT recursive traverse to subdirectories. + """ + paths: set[str] = set() + for dir_path in dir_paths: + all_files = os.listdir(dir_path) + python_files = {fname for fname in all_files if ".py" == fname[-3:]} + filter_files = { + fname for fname in python_files if fname not in files_to_exclude + } + paths.update({os.path.join(dir_path, fname) for fname in filter_files}) + return paths + + +def extract_method_name(line: str) -> str: + """Extract method name from decorator in the form of "@functional_datapipe({method_name})".""" + if '("' in line: + start_token, end_token = '("', '")' + elif "('" in line: + start_token, end_token = "('", "')" + else: + raise RuntimeError( + f"Unable to find appropriate method name within line:\n{line}" + ) + start, end = line.find(start_token) + len(start_token), line.find(end_token) + return line[start:end] + + +def extract_class_name(line: str) -> str: + """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):".""" + start_token = "class " + end_token = "(" + start, end = line.find(start_token) + len(start_token), line.find(end_token) + return line[start:end] + + +def parse_datapipe_file( + file_path: str, +) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]: + """Given a path to file, parses the file and returns a dictionary of method names to function signatures.""" + method_to_signature, method_to_class_name, special_output_type = {}, {}, set() + doc_string_dict = defaultdict(list) + with open(file_path, encoding="utf-8") as f: + open_paren_count = 0 + method_name, class_name, signature = "", "", "" + skip = False + for line in f: + if line.count('"""') % 2 == 1: + skip = not skip + if skip or '"""' in line: # Saving docstrings + doc_string_dict[method_name].append(line) + continue + if "@functional_datapipe" in line: + method_name = extract_method_name(line) + doc_string_dict[method_name] = [] + continue + if method_name and "class " in line: + class_name = extract_class_name(line) + continue + if method_name and ("def __init__(" in line or "def __new__(" in line): + if "def __new__(" in line: + special_output_type.add(method_name) + open_paren_count += 1 + start = line.find("(") + len("(") + line = line[start:] + if open_paren_count > 0: + open_paren_count += line.count("(") + open_paren_count -= line.count(")") + if open_paren_count == 0: + end = line.rfind(")") + signature += line[:end] + method_to_signature[method_name] = process_signature(signature) + method_to_class_name[method_name] = class_name + method_name, class_name, signature = "", "", "" + elif open_paren_count < 0: + raise RuntimeError( + "open parenthesis count < 0. This shouldn't be possible." + ) + else: + signature += line.strip() + return ( + method_to_signature, + method_to_class_name, + special_output_type, + doc_string_dict, + ) + + +def parse_datapipe_files( + file_paths: set[str], +) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]: + methods_and_signatures = {} + methods_and_class_names = {} + methods_with_special_output_types = set() + methods_and_doc_strings = {} + for path in file_paths: + ( + method_to_signature, + method_to_class_name, + methods_needing_special_output_types, + doc_string_dict, + ) = parse_datapipe_file(path) + methods_and_signatures.update(method_to_signature) + methods_and_class_names.update(method_to_class_name) + methods_with_special_output_types.update(methods_needing_special_output_types) + methods_and_doc_strings.update(doc_string_dict) + return ( + methods_and_signatures, + methods_and_class_names, + methods_with_special_output_types, + methods_and_doc_strings, + ) + + +def split_outside_bracket(line: str, delimiter: str = ",") -> list[str]: + """Given a line of text, split it on comma unless the comma is within a bracket '[]'.""" + bracket_count = 0 + curr_token = "" + res = [] + for char in line: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + elif char == delimiter and bracket_count == 0: + res.append(curr_token) + curr_token = "" + continue + curr_token += char + res.append(curr_token) + return res + + +def process_signature(line: str) -> list[str]: + """ + Clean up a given raw function signature. + + This includes removing the self-referential datapipe argument, default + arguments of input functions, newlines, and spaces. + """ + tokens: list[str] = split_outside_bracket(line) + for i, token in enumerate(tokens): + tokens[i] = token.strip(" ") + if token == "cls": + tokens[i] = "self" + elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"): + # Remove the datapipe after 'self' or 'cls' unless it has '*' + tokens[i] = "" + elif "Callable =" in token: # Remove default argument if it is a function + head = token.rpartition("=")[0] + tokens[i] = head.strip(" ") + " = ..." + tokens = [t for t in tokens if t != ""] + return tokens + + +def get_method_definitions( + file_path: str | list[str], + files_to_exclude: set[str], + deprecated_files: set[str], + default_output_type: str, + method_to_special_output_type: dict[str, str], + root: str = "", +) -> list[str]: + """ + #.pyi generation for functional DataPipes Process. + + # 1. Find files that we want to process (exclude the ones who don't) + # 2. Parse method name and signature + # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces + """ + if root == "": + root = str(Path(__file__).parent.resolve()) + file_path = [file_path] if isinstance(file_path, str) else file_path + file_path = [os.path.join(root, path) for path in file_path] + file_paths = find_file_paths( + file_path, files_to_exclude=files_to_exclude.union(deprecated_files) + ) + ( + methods_and_signatures, + methods_and_class_names, + methods_w_special_output_types, + methods_and_doc_strings, + ) = parse_datapipe_files(file_paths) + + for fn_name in method_to_special_output_type: + if fn_name not in methods_w_special_output_types: + methods_w_special_output_types.add(fn_name) + + method_definitions = [] + for method_name, arguments in methods_and_signatures.items(): + class_name = methods_and_class_names[method_name] + if method_name in methods_w_special_output_types: + output_type = method_to_special_output_type[method_name] + else: + output_type = default_output_type + doc_string = "".join(methods_and_doc_strings[method_name]) + if doc_string == "": + doc_string = " ..." + else: + doc_string = "\n" + doc_string + definition = format_function_signature(method_name, arguments, output_type) + method_definitions.append( + f"# Functional form of '{class_name}'\n" + + definition.removesuffix("...").rstrip() # remove "..." + + doc_string, + ) + method_definitions.sort( + key=lambda s: s.split("\n")[1] + ) # sorting based on method_name + + return method_definitions + + +# Defined outside of main() so they can be imported by TorchData +iterDP_file_path: str = "iter" +iterDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"} +iterDP_deprecated_files: set[str] = set() +iterDP_method_to_special_output_type: dict[str, str] = { + "demux": "list[IterDataPipe]", + "fork": "list[IterDataPipe]", +} + +mapDP_file_path: str = "map" +mapDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"} +mapDP_deprecated_files: set[str] = set() +mapDP_method_to_special_output_type: dict[str, str] = {"shuffle": "IterDataPipe"} + + +def main() -> None: + """ + # Inject file into template datapipe.pyi.in. + + TODO: The current implementation of this script only generates interfaces for built-in methods. To generate + interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`. + """ + iter_method_definitions = get_method_definitions( + iterDP_file_path, + iterDP_files_to_exclude, + iterDP_deprecated_files, + "IterDataPipe", + iterDP_method_to_special_output_type, + ) + + map_method_definitions = get_method_definitions( + mapDP_file_path, + mapDP_files_to_exclude, + mapDP_deprecated_files, + "MapDataPipe", + mapDP_method_to_special_output_type, + ) + + path = Path(__file__).absolute().parent + fm = FileManager(install_dir=path, template_dir=path, dry_run=False) + fm.write_with_template( + "datapipe.pyi", + "datapipe.pyi.in", + lambda: { + "IterDataPipeMethods": iter_method_definitions, + "MapDataPipeMethods": map_method_definitions, + }, + ) + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05831250da468cc76e8c2cc8e4018373e8191951 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py @@ -0,0 +1,66 @@ +from torch.utils.data.datapipes.iter.callable import ( + CollatorIterDataPipe as Collator, + MapperIterDataPipe as Mapper, +) +from torch.utils.data.datapipes.iter.combinatorics import ( + SamplerIterDataPipe as Sampler, + ShufflerIterDataPipe as Shuffler, +) +from torch.utils.data.datapipes.iter.combining import ( + ConcaterIterDataPipe as Concater, + DemultiplexerIterDataPipe as Demultiplexer, + ForkerIterDataPipe as Forker, + MultiplexerIterDataPipe as Multiplexer, + ZipperIterDataPipe as Zipper, +) +from torch.utils.data.datapipes.iter.filelister import ( + FileListerIterDataPipe as FileLister, +) +from torch.utils.data.datapipes.iter.fileopener import ( + FileOpenerIterDataPipe as FileOpener, +) +from torch.utils.data.datapipes.iter.grouping import ( + BatcherIterDataPipe as Batcher, + GrouperIterDataPipe as Grouper, + UnBatcherIterDataPipe as UnBatcher, +) +from torch.utils.data.datapipes.iter.routeddecoder import ( + RoutedDecoderIterDataPipe as RoutedDecoder, +) +from torch.utils.data.datapipes.iter.selecting import FilterIterDataPipe as Filter +from torch.utils.data.datapipes.iter.sharding import ( + ShardingFilterIterDataPipe as ShardingFilter, +) +from torch.utils.data.datapipes.iter.streamreader import ( + StreamReaderIterDataPipe as StreamReader, +) +from torch.utils.data.datapipes.iter.utils import ( + IterableWrapperIterDataPipe as IterableWrapper, +) + + +__all__ = [ + "Batcher", + "Collator", + "Concater", + "Demultiplexer", + "FileLister", + "FileOpener", + "Filter", + "Forker", + "Grouper", + "IterableWrapper", + "Mapper", + "Multiplexer", + "RoutedDecoder", + "Sampler", + "ShardingFilter", + "Shuffler", + "StreamReader", + "UnBatcher", + "Zipper", +] + +# Please keep this list sorted +if __all__ != sorted(__all__): + raise AssertionError("__all__ is not sorted") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7142a69111e6fcefad372796fce5041389bf2829 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d42b7eecf42f62239267a627f4b0ecff8486d49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d61ebc0068f0cee0d0a29da03540cf165fa9fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e708cc66b8f88fe9c65d269366a94d940cc3513e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b144aa5ca0add84ed5c320d7d4f9538c7ee6d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb215d6f9d1081a7d2df28d807c7770a93df9a84 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2767c9bc9f1c750a539466315c3c1d1699564aaa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..842590716db99c314ed0636d3f138e566bbf05e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79d45a1dfb38f2044525862e1471ce0a5cfe6431 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45e34a14563e2cf890ce7a5ff46a010583f59866 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f0a9f6b88234ed33f188251f70bef0bb1145f06 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69df291ec857b3272a96e37db32324a61d3791a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py new file mode 100644 index 0000000000000000000000000000000000000000..af1d9792c097b277c088bf03a5dd05c57ba75706 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py @@ -0,0 +1,244 @@ +# mypy: allow-untyped-defs +import functools +from collections import namedtuple +from collections.abc import Callable, Iterator, Sized +from typing import Any, TypeVar + +import torch +from torch.utils.data._utils.collate import default_collate +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import ( + _check_unpickable_fn, + validate_input_col, +) + + +__all__ = [ + "CollatorIterDataPipe", + "MapperIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("map") +class MapperIterDataPipe(IterDataPipe[_T_co]): + r""" + Applies a function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source Iterable DataPipe + fn: Function being applied over each item + input_col: Index or indices of data which ``fn`` is applied, such as: + + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified + only when ``input_col`` is not ``None`` + + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with + multiple indices, the left-most one is used, and other indices will be removed. + - Integer is used for list/tuple. ``-1`` represents to append result at the end. + - Key is used for dict. New key is acceptable. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = IterableWrapper(range(10)) + >>> # Invocation via functional form is preferred + ... map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` + >>> # Use `functools.partial` or explicitly define the function instead + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + + datapipe: IterDataPipe + fn: Callable + + def __init__( + self, + datapipe: IterDataPipe, + fn: Callable, + input_col=None, + output_col=None, + ) -> None: + torch._C._log_api_usage_once("python.data_pipes.map") + super().__init__() + self.datapipe = datapipe + + _check_unpickable_fn(fn) + self.fn = fn # type: ignore[assignment] + + self.input_col = input_col + if input_col is None and output_col is not None: + raise ValueError("`output_col` must be None when `input_col` is None.") + if isinstance(output_col, (list, tuple)): + if len(output_col) > 1: + raise ValueError("`output_col` must be a single-element list or tuple") + output_col = output_col[0] + self.output_col = output_col + validate_input_col(fn, input_col) + + def _apply_fn(self, data): + if self.input_col is None and self.output_col is None: + return self.fn(data) + + if self.input_col is None: + res = self.fn(data) + elif isinstance(self.input_col, (list, tuple)): + args = tuple(data[col] for col in self.input_col) + res = self.fn(*args) + else: + res = self.fn(data[self.input_col]) + + # Copy tuple to list and run in-place modification because tuple is immutable. + if isinstance(data, tuple): + t_flag = True + data = list(data) + else: + t_flag = False + + if self.output_col is None: + if isinstance(self.input_col, (list, tuple)): + data[self.input_col[0]] = res + for idx in sorted(self.input_col[1:], reverse=True): + del data[idx] + else: + # pyrefly: ignore [unsupported-operation] + data[self.input_col] = res + else: + if self.output_col == -1: + data.append(res) + else: + data[self.output_col] = res + + # Convert list back to tuple + return tuple(data) if t_flag else data + + def __iter__(self) -> Iterator[_T_co]: + for data in self.datapipe: + yield self._apply_fn(data) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +def _collate_helper(conversion, item): + # TODO(VitalyFedyunin): Verify that item is any sort of batch + if len(item.items) > 1: + # TODO(VitalyFedyunin): Compact all batch dataframes into one + raise RuntimeError("Only supports one DataFrame per batch") + df = item[0] + columns_name = df_wrapper.get_columns(df) + tuple_names: list = [] + tuple_values: list = [] + + for name in conversion: + if name not in columns_name: + raise RuntimeError("Conversion keys mismatch") + + for name in columns_name: + if name in conversion: + if not callable(conversion[name]): + raise RuntimeError( + "Collate (DF)DataPipe requires callable as dict values" + ) + collation_fn = conversion[name] + else: + # TODO(VitalyFedyunin): Add default collation into df_wrapper + try: + import torcharrow.pytorch as tap # type: ignore[import] + + collation_fn = tap.rec.Default() + except Exception as e: + raise RuntimeError( + "unable to import default collation function from the TorchArrow" + ) from e + + tuple_names.append(str(name)) + value = collation_fn(df[name]) + tuple_values.append(value) + + # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here + # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty + tpl_cls = namedtuple("CollateResult", tuple_names) # type: ignore[misc] + tuple = tpl_cls(*tuple_values) + return tuple + + +@functional_datapipe("collate") +class CollatorIterDataPipe(MapperIterDataPipe): + r""" + Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``). + + By default, it uses :func:`torch.utils.data.default_collate`. + + .. note:: + While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the + default behavior and `functools.partial` to specify any additional arguments. + + Args: + datapipe: Iterable DataPipe being collated + collate_fn: Customized collate function to collect and combine data or a batch of data. + Default function collates to Tensor(s) based on data type. + + Example: + >>> # xdoctest: +SKIP + >>> # Convert integer data to float Tensor + >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): + ... def __init__(self, start, end): + ... super(MyIterDataPipe).__init__() + ... assert end > start, "this example only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + ... def __len__(self): + ... return self.end - self.start + >>> ds = MyIterDataPipe(start=3, end=7) + >>> print(list(ds)) + [3, 4, 5, 6] + >>> def collate_fn(batch): + ... return torch.tensor(batch, dtype=torch.float) + >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) + >>> print(list(collated_ds)) + [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] + """ + + def __init__( + self, + datapipe: IterDataPipe, + conversion: Callable[..., Any] + | dict[str | Any, Callable | Any] + | None = default_collate, + collate_fn: Callable | None = None, + ) -> None: + # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]` + # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]` + if collate_fn is not None: + super().__init__(datapipe, fn=collate_fn) + else: + if callable(conversion): + super().__init__(datapipe, fn=conversion) + else: + # TODO(VitalyFedyunin): Validate passed dictionary + collate_fn = functools.partial(_collate_helper, conversion) + super().__init__(datapipe, fn=collate_fn) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py new file mode 100644 index 0000000000000000000000000000000000000000..79a774c5e63db9494c526a94b45ff5284e8e4ec1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py @@ -0,0 +1,193 @@ +# mypy: allow-untyped-defs +import random +from collections.abc import Iterator, Sized +from typing import TypeVar + +import torch +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.sampler import Sampler, SequentialSampler + + +__all__ = [ + "SamplerIterDataPipe", + "ShufflerIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class SamplerIterDataPipe(IterDataPipe[_T_co]): + r""" + Generate sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`). + + Args: + datapipe: IterDataPipe to sample from + sampler: Sampler class to generate sample elements from input DataPipe. + Default is :class:`SequentialSampler` for IterDataPipe + """ + + datapipe: IterDataPipe + sampler: Sampler + + def __init__( + self, + datapipe: IterDataPipe, + sampler: type[Sampler] = SequentialSampler, + sampler_args: tuple | None = None, + sampler_kwargs: dict | None = None, + ) -> None: + if not isinstance(datapipe, Sized): + raise AssertionError( + "Sampler class requires input datapipe implemented `__len__`" + ) + super().__init__() + # pyrefly: ignore [bad-assignment] + self.datapipe = datapipe + self.sampler_args = () if sampler_args is None else sampler_args + self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs + self.sampler_kwargs["data_source"] = self.datapipe + self.sampler = sampler(*self.sampler_args, **self.sampler_kwargs) + + def __iter__(self) -> Iterator[_T_co]: + return iter(self.sampler) + + def __len__(self) -> int: + # Dataset has been tested as `Sized` + if isinstance(self.sampler, Sized): + return len(self.sampler) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("shuffle") +class ShufflerIterDataPipe(IterDataPipe[_T_co]): + r""" + Shuffle the input DataPipe with a buffer (functional name: ``shuffle``). + + The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then, + each item will be yielded from the buffer by reservoir sampling via iterator. + + ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the + datapipe is not shuffled. In order to fully shuffle all elements from datapipe, + ``buffer_size`` is required to be greater than or equal to the size of datapipe. + + When it is used with :class:`torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed + for each worker process. + + Args: + datapipe: The IterDataPipe being shuffled + buffer_size: The buffer size for shuffling (default to ``10000``) + unbatch_level: Specifies if it is necessary to unbatch source data before + applying the shuffle + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> shuffle_dp = dp.shuffle() + >>> list(shuffle_dp) + [0, 4, 1, 6, 3, 2, 9, 5, 7, 8] + """ + + datapipe: IterDataPipe[_T_co] + buffer_size: int + _buffer: list[_T_co] + _enabled: bool + _seed: int | None + _rng: random.Random + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + *, + buffer_size: int = 10000, + unbatch_level: int = 0, + ) -> None: + super().__init__() + # TODO: Performance optimization + # buffer can be a fixed size and remove expensive `append()` and `len()` operations + self._buffer: list[_T_co] = [] + if buffer_size <= 0: + raise AssertionError("buffer_size should be larger than 0") + if unbatch_level == 0: + self.datapipe = datapipe + else: + self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level) + self.buffer_size = buffer_size + self._enabled = True + self._seed = None + self._rng = random.Random() + + def set_shuffle(self, shuffle=True): + self._enabled = shuffle + return self + + def set_seed(self, seed: int): + self._seed = seed + return self + + def __iter__(self) -> Iterator[_T_co]: + if not self._enabled: + yield from self.datapipe + else: + for x in self.datapipe: + if len(self._buffer) == self.buffer_size: + idx = self._rng.randint(0, len(self._buffer) - 1) + val, self._buffer[idx] = self._buffer[idx], x + yield val + else: + self._buffer.append(x) + while self._buffer: + idx = self._rng.randint(0, len(self._buffer) - 1) + yield self._buffer.pop(idx) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + def reset(self) -> None: + self._buffer = [] + if self._enabled: + if self._seed is None: + self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._rng.seed(self._seed) + self._seed = None + + def __getstate__(self): + state = ( + self.datapipe, + self.buffer_size, + self._enabled, + self._seed, + self._buffer, + self._rng.getstate(), + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.buffer_size, + self._enabled, + self._seed, + self._buffer, + rng_state, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._rng = random.Random() + self._rng.setstate(rng_state) + + def __del__(self) -> None: + self._buffer.clear() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py new file mode 100644 index 0000000000000000000000000000000000000000..4915e4c3d7c52a2844d1c65ce3adcc089622b25f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py @@ -0,0 +1,715 @@ +# mypy: allow-untyped-defs +import copy as copymodule +import warnings +from abc import ABC, abstractmethod +from collections import deque +from collections.abc import Callable, Iterator, Sized +from typing import Any, Literal, TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, StreamWrapper + + +__all__ = [ + "ConcaterIterDataPipe", + "DemultiplexerIterDataPipe", + "ForkerIterDataPipe", + "MultiplexerIterDataPipe", + "ZipperIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("concat") +class ConcaterIterDataPipe(IterDataPipe): + r""" + Concatenates multiple Iterable DataPipes (functional name: ``concat``). + + The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones. + + Args: + datapipes: Iterable DataPipes being concatenated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> import random + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1 = IterableWrapper(range(3)) + >>> dp2 = IterableWrapper(range(5)) + >>> list(dp1.concat(dp2)) + [0, 1, 2, 0, 1, 2, 3, 4] + """ + + datapipes: tuple[IterDataPipe] + + def __init__(self, *datapipes: IterDataPipe) -> None: + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `IterDataPipe`") + self.datapipes = datapipes # type: ignore[assignment] + + def __iter__(self) -> Iterator: + for dp in self.datapipes: + yield from dp + + def __len__(self) -> int: + if all(isinstance(dp, Sized) for dp in self.datapipes): + # pyrefly: ignore [bad-argument-type] + return sum(len(dp) for dp in self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("fork") +class ForkerIterDataPipe(IterDataPipe): + r""" + Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``). + + Args: + datapipe: Iterable DataPipe being copied + num_instances: number of instances of the datapipe to create + buffer_size: this restricts how far ahead the leading child DataPipe + can read relative to the slowest child DataPipe. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + copy: copy strategy to use for items yielded by each branch. Supported + options are ``None`` for no copying, ``"shallow"`` for shallow object + copies, and ``"deep"`` for deep object copies. Defaults to ``None``. + + Note: + All branches of the forked pipeline return the identical object unless + the copy parameter is supplied. If the object is mutable or contains + mutable objects, changing them in one branch will affect all others. + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.fork(num_instances=2) + >>> list(dp1) + [0, 1, 2, 3, 4] + >>> list(dp2) + [0, 1, 2, 3, 4] + """ + + def __new__( + cls, + datapipe: IterDataPipe, + num_instances: int, + buffer_size: int = 1000, + copy: Literal["shallow", "deep"] | None = None, + ): + if num_instances < 1: + raise ValueError( + f"Expected `num_instances` larger than 0, but {num_instances} is found" + ) + if num_instances == 1: + return datapipe + container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size, copy) # type: ignore[abstract] + return [_ChildDataPipe(container, i) for i in range(num_instances)] + + +class _ContainerTemplate(ABC): + r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" + + @abstractmethod + def get_next_element_by_instance(self, instance_id: int): ... + + @abstractmethod + def is_every_instance_exhausted(self) -> bool: ... + + @abstractmethod + def reset(self) -> None: ... + + @abstractmethod + def get_length_by_instance(self, instance_id: int): + r"""Raise TypeError if it's not supposed to be implemented to support `list(datapipe)`.""" + + +def _no_op(x): + return x + + +class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate): + r""" + Container to hold instance-specific information on behalf of ForkerIterDataPipe. + + It tracks the state of its child DataPipes, maintains the buffer, and yields the next value + as requested by the child DataPipes. + """ + + def __init__( + self, + datapipe: IterDataPipe, + num_instances: int, + buffer_size: int = 1000, + copy: Literal["shallow", "deep"] | None = None, + ) -> None: + self.main_datapipe = datapipe + self._datapipe_iterator: Iterator[Any] | None = None + self.num_instances = num_instances + self.buffer: deque = deque() + self.buffer_size = buffer_size + if self.buffer_size < 0: + warnings.warn( + "Unlimited buffer size is set for `fork`, " + "please be aware of OOM at random places", + UserWarning, + stacklevel=2, + ) + if copy is None: + self.copy_fn = _no_op + elif copy == "shallow": + self.copy_fn = copymodule.copy + elif copy == "deep": + self.copy_fn = copymodule.deepcopy + else: + raise ValueError( + f"Unknown copy method `{copy}` requested, choose one of None, `shallow` or `deep`." + ) + + self.child_pointers: list[int] = [ + 0 + ] * num_instances # Indicate the indices of the next element to get + self.slowest_ptr = 0 # The index to read by the slowest child + self.leading_ptr = 0 # The index to read by the fastest child + self.end_ptr: int | None = None # The index to stop child + self._child_stop: list[bool] = [True for _ in range(num_instances)] + + def __len__(self) -> int: + # pyrefly: ignore [bad-argument-type] + return len(self.main_datapipe) + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None and self._child_stop[instance_id]: + self._datapipe_iterator = iter(self.main_datapipe) + self._snapshot_state = _SnapshotState.Iterating + for i in range(self.num_instances): + self._child_stop[i] = False + try: + while not self._child_stop[instance_id]: + self.child_pointers[instance_id] += 1 + if ( + self.end_ptr is not None + and self.child_pointers[instance_id] == self.end_ptr + ): + self._child_stop[instance_id] = True + break + # Use buffer + if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr: + idx = self.child_pointers[instance_id] - self.slowest_ptr - 1 + return_val = self.buffer[idx] + else: # Retrieve one element from main datapipe + self.leading_ptr = self.child_pointers[instance_id] + try: + return_val = next(self._datapipe_iterator) # type: ignore[arg-type] + self.buffer.append(return_val) + except StopIteration: + self._child_stop[instance_id] = True + self._datapipe_iterator = None + self.end_ptr = self.leading_ptr + continue + if self.child_pointers[instance_id] == self.slowest_ptr + 1: + new_min = min( + self.child_pointers + ) # Can optimize by avoiding the call to min() + if self.slowest_ptr < new_min: + self.slowest_ptr = new_min + self.buffer.popleft() + if ( + self.buffer_size >= 0 + and self.leading_ptr > self.buffer_size + self.slowest_ptr + ): + raise BufferError( + "ForkerIterDataPipe buffer overflow," + + f"buffer size {self.buffer_size} is insufficient." + ) + + yield self.copy_fn(return_val) # type: ignore[possibly-undefined] + finally: + self._child_stop[instance_id] = True + # Cleanup _datapipe_iterator for the case that fork exits earlier + if all(self._child_stop): + self._datapipe_iterator = None + self._cleanup() + + def is_every_instance_exhausted(self) -> bool: + return self.end_ptr is not None and all(self._child_stop) + + def get_length_by_instance(self, instance_id: int) -> int: + # pyrefly: ignore [bad-argument-type] + return len(self.main_datapipe) + + def reset(self) -> None: + self._datapipe_iterator = None + self.buffer = deque() + self.child_pointers = [0] * self.num_instances + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr = None + self._child_stop = [True for _ in range(self.num_instances)] + + def __getstate__(self): + state = ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.copy_fn, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.copy_fn, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._datapipe_iterator = None + self.buffer = deque() + self.child_pointers = [0] * self.num_instances + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr = None + self._child_stop = [True for _ in range(self.num_instances)] + + def _cleanup(self) -> None: + while self.buffer: + d = self.buffer.popleft() + StreamWrapper.close_streams(d) + + def __del__(self) -> None: + self._cleanup() + + +class _ChildDataPipe(IterDataPipe): + r""" + Iterable Datapipe that is a child of a main DataPipe. + + The instance of this class will pass its instance_id to get the next value from its main DataPipe. + + Note: + ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint. + Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes, + the previous iterators for all ChildDataPipes must be invalidated, with the exception when a ChildDataPipe + hasn't had an iterator created from it since the last invalidation. See the example below. + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> # Singler Iterator per IteraDataPipe Invalidation + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(10)) + >>> cdp1, cdp2 = source_dp.fork(num_instances=2) + >>> it1, it2 = iter(cdp1), iter(cdp2) + >>> it3 = iter(cdp1) + >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`. + >>> it4 = iter(cdp2) + >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since + >>> # the last invalidation. + + Args: + main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)' + instance_id: integer identifier of this instance + """ + + _is_child_datapipe: bool = True + + def __init__(self, main_datapipe: IterDataPipe, instance_id: int) -> None: + if not isinstance(main_datapipe, _ContainerTemplate): + raise AssertionError("main_datapipe must implement _ContainerTemplate") + + # pyrefly: ignore [bad-assignment] + self.main_datapipe: IterDataPipe = main_datapipe + self.instance_id = instance_id + + def __iter__(self): + # Note that the logic behind setting iterator ID and `reset` are handled within `hook_iterator` + # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called + return self.main_datapipe.get_next_element_by_instance(self.instance_id) + + def __len__(self) -> int: + return self.main_datapipe.get_length_by_instance(self.instance_id) + + # This method is called by `hook_iterator` in `_typing.py`. + def _set_main_datapipe_valid_iterator_id(self) -> int: + r""" + Update the valid iterator ID for both this DataPipe object and `main_datapipe`. + + `main_datapipe.reset()` is called when the ID is incremented to a new generation. + """ + # 1. First time any child iterator is created + if self.main_datapipe._valid_iterator_id is None: + self.main_datapipe._valid_iterator_id = 0 # type: ignore[attr-defined] + # 2. This instance was already in the same generation as `main_datapipe`, + # we need to increment the ID further by 1 + elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id: # type: ignore[has-type] + self.main_datapipe._valid_iterator_id += 1 # type: ignore[attr-defined] + # Whenever a new generation of iterator is created, the `main_datapipe` must reset + if not self.main_datapipe.is_every_instance_exhausted(): + warnings.warn( + "Some child DataPipes are not exhausted when __iter__ is called. We are resetting " + "the buffer and each child DataPipe will read from the start again.", + UserWarning, + stacklevel=2, + ) + self.main_datapipe.reset() + # 3. Otherwise, the iterator is behind the others, so it will just need to catch up by setting + # the instance's iterator to match that of `main_datapipe` + self._valid_iterator_id = self.main_datapipe._valid_iterator_id + return self._valid_iterator_id + + # This method is called by `hook_iterator` in `_typing.py`. + def _check_valid_iterator_id(self, iterator_id) -> bool: + r"""Check the valid iterator ID against that of DataPipe object and that of `main_datapipe`.""" + return ( + iterator_id == self._valid_iterator_id + and iterator_id == self.main_datapipe._valid_iterator_id + ) + + +@functional_datapipe("demux") +class DemultiplexerIterDataPipe(IterDataPipe): + r""" + Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``). + + A list of the child DataPipes is returned from this operation. + + Args: + datapipe: Iterable DataPipe being filtered + num_instances: number of instances of the DataPipe to create + classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None`` + drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None`` + buffer_size: this defines the maximum number of inputs that the buffer can hold across all child + DataPipes while waiting for their values to be yielded. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + + Examples: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def odd_or_even(n): + ... return n % 2 + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even) + >>> list(dp1) + [0, 2, 4] + >>> list(dp2) + [1, 3] + >>> # It can also filter out any element that gets `None` from the `classifier_fn` + >>> def odd_or_even_no_zero(n): + ... return n % 2 if n != 0 else None + >>> dp1, dp2 = source_dp.demux( + ... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True + ... ) + >>> list(dp1) + [2, 4] + >>> list(dp2) + [1, 3] + """ + + def __new__( + cls, + datapipe: IterDataPipe, + num_instances: int, + classifier_fn: Callable[[_T_co], int | None], + drop_none: bool = False, + buffer_size: int = 1000, + ): + if num_instances < 1: + raise ValueError( + f"Expected `num_instances` larger than 0, but {num_instances} is found" + ) + + _check_unpickable_fn(classifier_fn) + + # When num_instances == 1, demux can be replaced by filter, + # but keep it as Demultiplexer for the sake of consistency + # like throwing Error when classification result is out of o range + container = _DemultiplexerIterDataPipe( + datapipe, num_instances, classifier_fn, drop_none, buffer_size + ) # type: ignore[abstract] + return [_ChildDataPipe(container, i) for i in range(num_instances)] + + +class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate): + r""" + Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. + + It tracks the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value + as requested by the child DataPipes. + """ + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + num_instances: int, + classifier_fn: Callable[[_T_co], int | None], + drop_none: bool, + buffer_size: int, + ) -> None: + # pyrefly: ignore [invalid-type-var] + self.main_datapipe = datapipe + self._datapipe_iterator: Iterator[Any] | None = None + self.num_instances = num_instances + self.buffer_size = buffer_size + if self.buffer_size < 0: + warnings.warn( + "Unlimited buffer size is set for `demux`, " + "please be aware of OOM at random places", + UserWarning, + stacklevel=2, + ) + self.current_buffer_usage = 0 + # pyrefly: ignore [invalid-type-var] + self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)] + # pyrefly: ignore [invalid-type-var] + self.classifier_fn = classifier_fn + self.drop_none = drop_none + self.main_datapipe_exhausted = False + self._child_stop: list[bool] = [True for _ in range(num_instances)] + + def _find_next(self, instance_id: int) -> _T_co: # type: ignore[type-var] + while True: + if self.main_datapipe_exhausted or self._child_stop[instance_id]: + raise StopIteration + if self._datapipe_iterator is None: + raise ValueError( + "_datapipe_iterator has not been set, likely because this private method is called directly " + "without invoking get_next_element_by_instance() first." + ) + value = next(self._datapipe_iterator) + classification = self.classifier_fn(value) + if classification is None and self.drop_none: + StreamWrapper.close_streams(value) + continue + if ( + classification is None + or classification >= self.num_instances + or classification < 0 + ): + raise ValueError( + f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " + + f"{classification} is returned." + ) + if classification == instance_id: + return value + self.child_buffers[classification].append(value) + self.current_buffer_usage += 1 + if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size: + raise BufferError( + f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient." + ) + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None and self._child_stop[instance_id]: + self._datapipe_iterator = iter(self.main_datapipe) + self._snapshot_state = ( + _SnapshotState.Iterating + ) # This is necessary for the DataPipe to reset properly. + self.main_datapipe_exhausted = False + for i in range(self.num_instances): + self._child_stop[i] = False + + try: + while not self._child_stop[instance_id]: + if self.child_buffers[instance_id]: + self.current_buffer_usage -= 1 + yield self.child_buffers[instance_id].popleft() + else: + try: + yield self._find_next(instance_id) + except StopIteration: + self._child_stop[instance_id] = True + self.main_datapipe_exhausted = True + self._datapipe_iterator = None + finally: + self._child_stop[instance_id] = True + # Cleanup _datapipe_iterator for the case that demux exits earlier + if all(self._child_stop): + self._datapipe_iterator = None + if self.child_buffers[instance_id]: + self._cleanup(instance_id) + + def is_every_instance_exhausted(self) -> bool: + return self.main_datapipe_exhausted and all(self._child_stop) + + def get_length_by_instance(self, instance_id: int) -> int: + raise TypeError + + def reset(self) -> None: + self._datapipe_iterator = None + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self._child_stop = [True for _ in range(self.num_instances)] + self.main_datapipe_exhausted = False + + def __getstate__(self): + state = ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.classifier_fn, + self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.classifier_fn, + self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._datapipe_iterator = None + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self._child_stop = [True for _ in range(self.num_instances)] + self.main_datapipe_exhausted = False + + def _cleanup(self, instance_id: int | None = None) -> None: + ids = ( + range(self.num_instances) + if instance_id is None + else [ + instance_id, + ] + ) + for i in ids: + q = self.child_buffers[i] + while q: + d = q.popleft() + StreamWrapper.close_streams(d) + + def __del__(self) -> None: + self._cleanup() + + +@functional_datapipe("mux") +class MultiplexerIterDataPipe(IterDataPipe): + r""" + Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``). + + As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, + and so on. It ends when the shortest input DataPipe is exhausted. + + Args: + datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(3)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) + >>> list(dp1.mux(dp2, dp3)) + [0, 10, 20, 1, 11, 21, 2, 12, 22] + """ + + def __init__(self, *datapipes) -> None: + self.datapipes = datapipes + self.buffer: list = [] # Store values to be yielded only when every iterator provides one + + def __iter__(self): + iterators = [iter(x) for x in self.datapipes] + while iterators: + for it in iterators: + try: + value = next(it) + self.buffer.append(value) + except StopIteration: + self.buffer.clear() + return + yield from self.buffer + self.buffer.clear() + + def __len__(self) -> int: + if all(isinstance(dp, Sized) for dp in self.datapipes): + return min(len(dp) for dp in self.datapipes) * len(self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + def reset(self) -> None: + self.buffer = [] + + def __getstate__(self): + state = ( + self.datapipes, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipes, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self.buffer = [] + + def __del__(self) -> None: + self.buffer.clear() + + +@functional_datapipe("zip") +class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + The output is stopped as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Iterable DataPipes being aggregated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(5)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) + >>> list(dp1.zip(dp2, dp3)) + [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] + """ + + datapipes: tuple[IterDataPipe] + + def __init__(self, *datapipes: IterDataPipe) -> None: + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError( + "All inputs are required to be `IterDataPipe` for `ZipIterDataPipe`." + ) + super().__init__() + self.datapipes = datapipes # type: ignore[assignment] + + def __iter__(self) -> Iterator[tuple[_T_co]]: + iterators = [iter(datapipe) for datapipe in self.datapipes] + yield from zip(*iterators, strict=False) + + def __len__(self) -> int: + if all(isinstance(dp, Sized) for dp in self.datapipes): + # pyrefly: ignore [bad-argument-type] + return min(len(dp) for dp in self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py new file mode 100644 index 0000000000000000000000000000000000000000..352d3c01e12d278cb8e1308ee78feff0610808bc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py @@ -0,0 +1,67 @@ +from collections.abc import Iterator, Sequence + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe +from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root + + +__all__ = ["FileListerIterDataPipe"] + + +@functional_datapipe("list_files") +class FileListerIterDataPipe(IterDataPipe[str]): + r""" + Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory. + + Multiple root directories can be provided (functional name: ``list_files``). + + Args: + root: Root directory or a sequence of root directories + masks: Unix style filter string or string list for filtering file name(s) + recursive: Whether to return pathname from nested directories or not + abspath: Whether to return relative pathname or absolute pathname + non_deterministic: Whether to return pathname in sorted order or not. + If ``False``, the results yielded from each root directory will be sorted + length: Nominal length of the datapipe + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister + >>> dp = FileLister(root=".", recursive=True) + >>> list(dp) + ['example.py', './data/data.tar'] + """ + + def __init__( + self, + root: str | Sequence[str] | IterDataPipe = ".", + masks: str | list[str] = "", + *, + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, + length: int = -1, + ) -> None: + super().__init__() + if isinstance(root, str): + root = [root] + if not isinstance(root, IterDataPipe): + root = IterableWrapperIterDataPipe(root) + self.datapipe: IterDataPipe = root + self.masks: str | list[str] = masks + self.recursive: bool = recursive + self.abspath: bool = abspath + self.non_deterministic: bool = non_deterministic + self.length: int = length + + def __iter__(self) -> Iterator[str]: + for path in self.datapipe: + yield from get_file_pathnames_from_root( + path, self.masks, self.recursive, self.abspath, self.non_deterministic + ) + + def __len__(self) -> int: + if self.length == -1: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + return self.length diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py new file mode 100644 index 0000000000000000000000000000000000000000..e77f7a4c8e660ec0e2ff5374fcdc8a30c474ea03 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py @@ -0,0 +1,79 @@ +from collections.abc import Iterable, Iterator +from io import IOBase + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames + + +__all__ = [ + "FileOpenerIterDataPipe", +] + + +@functional_datapipe("open_files") +class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]): + r""" + Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``). + + Args: + datapipe: Iterable datapipe that provides pathnames + mode: An optional string that specifies the mode in which + the file is opened by ``open()``. It defaults to ``r``, other options are + ``b`` for reading in binary mode and ``t`` for text mode. + encoding: An optional string that specifies the encoding of the + underlying file. It defaults to ``None`` to match the default encoding of ``open``. + length: Nominal length of the datapipe + + Note: + The opened file handles will be closed by Python's GC periodically. Users can choose + to close them explicitly. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import ( + ... FileLister, + ... FileOpener, + ... StreamReader, + ... ) + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt")) + >>> dp = FileOpener(dp) + >>> dp = StreamReader(dp) + >>> list(dp) + [('./abc.txt', 'abc')] + """ + + def __init__( + self, + datapipe: Iterable[str], + mode: str = "r", + encoding: str | None = None, + length: int = -1, + ) -> None: + super().__init__() + self.datapipe: Iterable[str] = datapipe + self.mode: str = mode + self.encoding: str | None = encoding + + if self.mode not in ("b", "t", "rb", "rt", "r"): + raise ValueError(f"Invalid mode {mode}") + # TODO: enforce typing for each instance based on mode, otherwise + # `argument_validation` with this DataPipe may be potentially broken + + if "b" in mode and encoding is not None: + raise ValueError("binary mode doesn't take an encoding argument") + + self.length: int = length + + # Remove annotation due to 'IOBase' is a general type and true type + # is determined at runtime based on mode. Some `DataPipe` requiring + # a subtype would cause mypy error. + def __iter__(self) -> Iterator[tuple[str, IOBase]]: + yield from get_file_binaries_from_pathnames( + self.datapipe, self.mode, self.encoding + ) + + def __len__(self) -> int: + if self.length == -1: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + return self.length diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..b773f06823a768f07e5a5a528e2afc8b0467d548 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py @@ -0,0 +1,326 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from collections.abc import Callable, Iterator, Sized +from typing import Any, NoReturn, TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + + +__all__ = [ + "BatcherIterDataPipe", + "GrouperIterDataPipe", + "UnBatcherIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +def __getattr__(name: str) -> NoReturn: + raise AttributeError(f"module {__name__} has no attribute {name}") + + +@functional_datapipe("batch") +class BatcherIterDataPipe(IterDataPipe[DataChunk]): + r""" + Creates mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the + last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding, + defaults to ``DataChunk`` + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> dp = dp.batch(batch_size=3, drop_last=True) + >>> list(dp) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + datapipe: IterDataPipe + batch_size: int + drop_last: bool + + def __init__( + self, + datapipe: IterDataPipe, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> None: + if batch_size <= 0: + raise AssertionError("Batch size is required to be larger than 0!") + super().__init__() + self.datapipe = datapipe + self.batch_size = batch_size + self.drop_last = drop_last + self.wrapper_class = wrapper_class + + def __iter__(self) -> Iterator[DataChunk]: + batch: list = [] + for x in self.datapipe: + batch.append(x) + if len(batch) == self.batch_size: + yield self.wrapper_class(batch) + batch = [] + if len(batch) > 0: + if not self.drop_last: + yield self.wrapper_class(batch) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + if self.drop_last: + return len(self.datapipe) // self.batch_size + else: + return (len(self.datapipe) + self.batch_size - 1) // self.batch_size + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("unbatch") +class UnBatcherIterDataPipe(IterDataPipe): + r""" + Undos batching of data (functional name: ``unbatch``). + + In other words, it flattens the data up to the specified level within a batched DataPipe. + + Args: + datapipe: Iterable DataPipe being un-batched + unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``, + it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]]) + >>> dp1 = source_dp.unbatch() + >>> list(dp1) + [[0, 1], [2], [3, 4], [5], [6]] + >>> dp2 = source_dp.unbatch(unbatch_level=2) + >>> list(dp2) + [0, 1, 2, 3, 4, 5, 6] + """ + + def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1) -> None: + self.datapipe = datapipe + self.unbatch_level = unbatch_level + + def __iter__(self): + for element in self.datapipe: + yield from self._dive(element, unbatch_level=self.unbatch_level) + + def _dive(self, element, unbatch_level): + if unbatch_level < -1: + raise ValueError("unbatch_level must be -1 or >= 0") + if unbatch_level == -1: + if isinstance(element, (list, DataChunk)): + for item in element: + yield from self._dive(item, unbatch_level=-1) + else: + yield element + elif unbatch_level == 0: + yield element + else: + if isinstance(element, (list, DataChunk)): + for item in element: + yield from self._dive(item, unbatch_level=unbatch_level - 1) + else: + raise IndexError( + f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe" + ) + + +@functional_datapipe("groupby") +class GrouperIterDataPipe(IterDataPipe[DataChunk]): + r""" + Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``. + + (functional name: ``groupby``). + + The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group + will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full, + the DataPipe will yield the largest batch with the same key, provided that its size is larger + than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``. + + After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity + will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``. + + Args: + datapipe: Iterable datapipe to be grouped + group_key_fn: Function used to generate group key from the data of the source datapipe + keep_key: Option to yield the matching key along with the items in a tuple, + resulting in `(key, [items])` otherwise returning [items] + buffer_size: The size of buffer for ungrouped data + group_size: The max size of each group, a batch is yielded as soon as it reaches this size + guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full + drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer + when the buffer is full + + Example: + >>> import os + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def group_fn(file): + ... return os.path.basename(file).split(".")[0] + >>> source_dp = IterableWrapper( + ... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"] + ... ) + >>> dp0 = source_dp.groupby(group_key_fn=group_fn) + >>> list(dp0) + [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] + >>> # A group is yielded as soon as its size equals to `group_size` + >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2) + >>> list(dp1) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` + >>> dp2 = source_dp.groupby( + ... group_key_fn=group_fn, + ... buffer_size=3, + ... group_size=3, + ... guaranteed_group_size=2, + ... ) + >>> list(dp2) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + """ + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + group_key_fn: Callable[[_T_co], Any], + *, + keep_key: bool = False, + buffer_size: int = 10000, + group_size: int | None = None, + guaranteed_group_size: int | None = None, + drop_remaining: bool = False, + ) -> None: + _check_unpickable_fn(group_key_fn) + # pyrefly: ignore [invalid-type-var] + self.datapipe = datapipe + # pyrefly: ignore [invalid-type-var] + self.group_key_fn = group_key_fn + + self.keep_key = keep_key + self.max_buffer_size = buffer_size + self.buffer_elements: defaultdict[Any, list] = defaultdict(list) + self.curr_buffer_size = 0 + self.group_size = group_size + self.guaranteed_group_size = None + if group_size is not None and buffer_size is not None: + if not (0 < group_size <= buffer_size): + raise AssertionError("group_size must be > 0 and <= buffer_size") + # pyrefly: ignore [bad-assignment] + self.guaranteed_group_size = group_size + if guaranteed_group_size is not None: + if group_size is None or not (0 < guaranteed_group_size <= group_size): + raise AssertionError( + "guaranteed_group_size must be > 0 and <= group_size and group_size must be set" + ) + # pyrefly: ignore [bad-assignment] + self.guaranteed_group_size = guaranteed_group_size + self.drop_remaining = drop_remaining + self.wrapper_class = DataChunk + + def _remove_biggest_key(self): + biggest_key = None + biggest_size = 0 + result_to_yield = None + for findkey in self.buffer_elements: + if len(self.buffer_elements[findkey]) > biggest_size: + biggest_size = len(self.buffer_elements[findkey]) + biggest_key = findkey + + if ( + self.guaranteed_group_size is not None + and biggest_size < self.guaranteed_group_size + and not self.drop_remaining + ): + raise RuntimeError( + "Failed to group items", str(self.buffer_elements[biggest_key]) + ) + + if ( + self.guaranteed_group_size is None + or biggest_size >= self.guaranteed_group_size + ): + result_to_yield = self.buffer_elements[biggest_key] + + self.curr_buffer_size -= biggest_size + del self.buffer_elements[biggest_key] + + return result_to_yield + + def __iter__(self): + for x in self.datapipe: + key = self.group_key_fn(x) + + self.buffer_elements[key].append(x) + self.curr_buffer_size += 1 + + if self.group_size is not None and self.group_size == len( + self.buffer_elements[key] + ): + result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key]) + yield (key, result) if self.keep_key else result + self.curr_buffer_size -= len(self.buffer_elements[key]) + del self.buffer_elements[key] + + if self.curr_buffer_size == self.max_buffer_size: + result_to_yield = self._remove_biggest_key() + if result_to_yield is not None: + result = self.wrapper_class(result_to_yield) + yield (key, result) if self.keep_key else result + + for key in tuple(self.buffer_elements.keys()): + result = self.wrapper_class(self.buffer_elements.pop(key)) + self.curr_buffer_size -= len(result) + yield (key, result) if self.keep_key else result + + def reset(self) -> None: + self.curr_buffer_size = 0 + self.buffer_elements = defaultdict(list) + + def __getstate__(self): + state = ( + self.datapipe, + self.group_key_fn, + self.keep_key, + self.max_buffer_size, + self.group_size, + self.guaranteed_group_size, + self.drop_remaining, + self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.group_key_fn, + self.keep_key, + self.max_buffer_size, + self.group_size, + self.guaranteed_group_size, + self.drop_remaining, + self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self.curr_buffer_size = 0 + self.buffer_elements = defaultdict(list) + + def __del__(self) -> None: + self.buffer_elements.clear() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4d708a0a318bd75ab67f456b0a5ef2f24b2c81 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py @@ -0,0 +1,70 @@ +from collections.abc import Callable, Iterable, Iterator, Sized +from io import BufferedIOBase +from typing import Any + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import _deprecation_warning +from torch.utils.data.datapipes.utils.decoder import ( + basichandlers as decoder_basichandlers, + Decoder, + extension_extract_fn, + imagehandler as decoder_imagehandler, +) + + +__all__ = ["RoutedDecoderIterDataPipe"] + + +@functional_datapipe("routed_decode") +class RoutedDecoderIterDataPipe(IterDataPipe[tuple[str, Any]]): + r""" + Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple. + + (functional name: ``routed_decode``) + + Args: + datapipe: Iterable datapipe that provides pathname and binary stream in tuples + handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder + handlers will be set as default. If multiple handles are provided, the priority + order follows the order of handlers (the first handler has the top priority) + key_fn: Function for decoder to extract key from pathname to dispatch handlers. + Default is set to extract file extension from pathname + + Note: + When ``key_fn`` is specified returning anything other than extension, the default + handler will not work and users need to specify custom handler. Custom handler + could use regex to determine the eligibility to handle data. + """ + + def __init__( + self, + datapipe: Iterable[tuple[str, BufferedIOBase]], + *handlers: Callable, + key_fn: Callable = extension_extract_fn, + ) -> None: + super().__init__() + self.datapipe: Iterable[tuple[str, BufferedIOBase]] = datapipe + if not handlers: + handlers = (decoder_basichandlers, decoder_imagehandler("torch")) + self.decoder = Decoder(*handlers, key_fn=key_fn) + _deprecation_warning( + type(self).__name__, + deprecation_version="1.12", + removal_version="1.13", + old_functional_name="routed_decode", + ) + + def add_handler(self, *handler: Callable) -> None: + self.decoder.add_handler(*handler) + + def __iter__(self) -> Iterator[tuple[str, Any]]: + for data in self.datapipe: + pathname = data[0] + result = self.decoder(data) + yield (pathname, result[pathname]) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py new file mode 100644 index 0000000000000000000000000000000000000000..afb0e91d8557911aae6f20d830667c79f7764cc5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable, Iterator +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import ( + _check_unpickable_fn, + StreamWrapper, + validate_input_col, +) + + +__all__ = ["FilterIterDataPipe"] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("filter") +class FilterIterDataPipe(IterDataPipe[_T_co]): + r""" + Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``). + + Args: + datapipe: Iterable DataPipe being filtered + filter_fn: Customized function mapping an element to a boolean. + input_col: Index or indices of data which ``filter_fn`` is applied, such as: + + - ``None`` as default to apply ``filter_fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def is_even(n): + ... return n % 2 == 0 + >>> dp = IterableWrapper(range(5)) + >>> filter_dp = dp.filter(filter_fn=is_even) + >>> list(filter_dp) + [0, 2, 4] + """ + + datapipe: IterDataPipe[_T_co] + filter_fn: Callable + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + filter_fn: Callable, + input_col=None, + ) -> None: + super().__init__() + self.datapipe = datapipe + + _check_unpickable_fn(filter_fn) + self.filter_fn = filter_fn # type: ignore[assignment] + + self.input_col = input_col + validate_input_col(filter_fn, input_col) + + def _apply_filter_fn(self, data) -> bool: + if self.input_col is None: + return self.filter_fn(data) + elif isinstance(self.input_col, (list, tuple)): + args = tuple(data[col] for col in self.input_col) + return self.filter_fn(*args) + else: + return self.filter_fn(data[self.input_col]) + + def __iter__(self) -> Iterator[_T_co]: + for data in self.datapipe: + condition, filtered = self._returnIfTrue(data) + if condition: + yield filtered + else: + StreamWrapper.close_streams(data) + + def _returnIfTrue(self, data: _T) -> tuple[bool, _T]: + condition = self._apply_filter_fn(data) + + if df_wrapper.is_column(condition): + # We are operating on DataFrames filter here + result = [] + for idx, mask in enumerate(df_wrapper.iterate(condition)): + if mask: + result.append(df_wrapper.get_item(data, idx)) + if result: + return True, df_wrapper.concat(result) + else: + return False, None # type: ignore[return-value] + + if not isinstance(condition, bool): + raise ValueError( + "Boolean output is required for `filter_fn` of FilterIterDataPipe, got", + type(condition), + ) + + return condition, data diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..494ea0106a041eb78d31287a6f05a5c8434c3321 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from enum import IntEnum +from typing import NoReturn + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +__all__ = [ + "SHARDING_PRIORITIES", + "ShardingFilterIterDataPipe", +] + + +class SHARDING_PRIORITIES(IntEnum): + DEFAULT = 1 + DISTRIBUTED = 2 + MULTIPROCESSING = 3 + + +class _ShardingIterDataPipe(IterDataPipe): + def apply_sharding( + self, + num_of_instances: int, + instance_id: int, + sharding_group: SHARDING_PRIORITIES, + ) -> NoReturn: + raise NotImplementedError + + +@functional_datapipe("sharding_filter") +class ShardingFilterIterDataPipe(_ShardingIterDataPipe): + r""" + Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``). + + After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the + original DataPipe, where `n` equals to the number of instances. + + Args: + source_datapipe: Iterable DataPipe that will be sharded + """ + + def __init__( + self, source_datapipe: IterDataPipe, sharding_group_filter=None + ) -> None: + self.source_datapipe = source_datapipe + self.sharding_group_filter = sharding_group_filter + self.groups: dict[int, tuple[int, int]] = {} + self.num_of_instances = 1 + self.instance_id = 0 + self._update_num_of_instances() + + def apply_sharding( + self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT + ): + if instance_id >= num_of_instances: + raise ValueError( + f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})" + ) + if sharding_group == SHARDING_PRIORITIES.DEFAULT: + if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups: + raise RuntimeError( + "ShardingFilter cannot mix DEFAULT and non DEFAULT groups" + ) + else: + if SHARDING_PRIORITIES.DEFAULT in self.groups: + raise RuntimeError( + "ShardingFilter cannot mix DEFAULT and non DEFAULT groups" + ) + self.groups[sharding_group] = (num_of_instances, instance_id) + self._update_num_of_instances() + + def _update_num_of_instances(self) -> None: + sorted_sharding_groups = [ + self.groups[key] + for key in sorted(self.groups.keys()) + if self.sharding_group_filter is None or key == self.sharding_group_filter + ] + + sorted_sharding_groups.reverse() + + self.num_of_instances = 1 + self.instance_id = 0 + + for group_num_of_instances, group_instance_id in sorted_sharding_groups: + self.instance_id += self.num_of_instances * group_instance_id + self.num_of_instances *= group_num_of_instances + + def __iter__(self): + for i, item in enumerate(self.source_datapipe): + if i % self.num_of_instances == self.instance_id: + yield item + + def __len__(self) -> int: + if isinstance(self.source_datapipe, Sized): + return len(self.source_datapipe) // self.num_of_instances + ( + 1 + if ( + self.instance_id < len(self.source_datapipe) % self.num_of_instances + ) + else 0 + ) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py new file mode 100644 index 0000000000000000000000000000000000000000..1129c06548e1f406629e25a5a2f558dea3a1475e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py @@ -0,0 +1,45 @@ +from collections.abc import Iterator +from io import IOBase + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +__all__ = ["StreamReaderIterDataPipe"] + + +@functional_datapipe("read_from_stream") +class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]): + r""" + Given IO streams and their label names, yield bytes with label name as tuple. + + (functional name: ``read_from_stream``). + + Args: + datapipe: Iterable DataPipe provides label/URL and byte stream + chunk: Number of bytes to be read from stream per iteration. + If ``None``, all bytes will be read until the EOF. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader + >>> from io import StringIO + >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))]) + >>> list(StreamReader(dp, chunk=1)) + [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')] + """ + + def __init__( + self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: int | None = None + ) -> None: + self.datapipe = datapipe + self.chunk = chunk + + def __iter__(self) -> Iterator[tuple[str, bytes]]: + for furl, stream in self.datapipe: + while True: + d = stream.read(self.chunk) + if not d: + stream.close() + break + yield (furl, d) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e45ddab282f7b975732b28ab88339f979792646a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py @@ -0,0 +1,60 @@ +import copy +import warnings +from collections.abc import Iterable, Iterator, Sized +from typing import TypeVar + +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +_T = TypeVar("_T") + +__all__ = ["IterableWrapperIterDataPipe"] + + +class IterableWrapperIterDataPipe(IterDataPipe[_T]): + r""" + Wraps an iterable object to create an IterDataPipe. + + Args: + iterable: Iterable object to be wrapped into an IterDataPipe + deepcopy: Option to deepcopy input iterable object for each + iterator. The copy is made when the first element is read in ``iter()``. + + .. note:: + If ``deepcopy`` is explicitly set to ``False``, users should ensure + that the data pipeline doesn't contain any in-place operations over + the iterable instance to prevent data inconsistency across iterations. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> list(dp) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + + def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None: + self.iterable = iterable + self.deepcopy = deepcopy + + def __iter__(self) -> Iterator[_T]: + source_data = self.iterable + if self.deepcopy: + try: + source_data = copy.deepcopy(self.iterable) + # For the case that data cannot be deep-copied, + # all in-place operations will affect iterable variable. + # When this DataPipe is iterated second time, it will + # yield modified items. + except TypeError: + warnings.warn( + "The input iterable can not be deepcopied, " + "please be aware of in-place modification would affect source data.", + stacklevel=2, + ) + yield from source_data + + def __len__(self) -> int: + if isinstance(self.iterable, Sized): + return len(self.iterable) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc555e8fdac26039d36c4c1e1ba8309bfa8b4e5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py @@ -0,0 +1,20 @@ +# Functional DataPipe +from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper +from torch.utils.data.datapipes.map.combinatorics import ( + ShufflerIterDataPipe as Shuffler, +) +from torch.utils.data.datapipes.map.combining import ( + ConcaterMapDataPipe as Concater, + ZipperMapDataPipe as Zipper, +) +from torch.utils.data.datapipes.map.grouping import BatcherMapDataPipe as Batcher +from torch.utils.data.datapipes.map.utils import ( + SequenceWrapperMapDataPipe as SequenceWrapper, +) + + +__all__ = ["Batcher", "Concater", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"] + +# Please keep this list sorted +if __all__ != sorted(__all__): + raise AssertionError("__all__ is not sorted") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ccc693e81014fffbd268c1804a6d7e677b024fa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbad27841dde0bcc5facd1ffb32878ce075991a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75d09a13d6b330da5db0c2fc62f0f126e87fed02 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93744ca132368329201724d37583b00705a598c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f281fcdd45478d9a4697dc8749d4e132aac0a4dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c756e4ddd243d16b8516fa9858f81d6740269a99 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py new file mode 100644 index 0000000000000000000000000000000000000000..3696d34b2a815599709bb09d9b0dfcaca988a6eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import MapDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + + +__all__ = ["MapperMapDataPipe", "default_fn"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Default function to return each item directly +# In order to keep datapipe picklable, eliminates the usage +# of python lambda function +def default_fn(data): + return data + + +@functional_datapipe("map") +class MapperMapDataPipe(MapDataPipe[_T_co]): + r""" + Apply the input function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source MapDataPipe + fn: Function being applied to each item + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + + datapipe: MapDataPipe + fn: Callable + + def __init__( + self, + datapipe: MapDataPipe, + fn: Callable = default_fn, + ) -> None: + super().__init__() + self.datapipe = datapipe + _check_unpickable_fn(fn) + self.fn = fn # type: ignore[assignment] + + def __len__(self) -> int: + # pyrefly: ignore [bad-argument-type] + return len(self.datapipe) + + def __getitem__(self, index) -> _T_co: + return self.fn(self.datapipe[index]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py new file mode 100644 index 0000000000000000000000000000000000000000..af4792fc805b824d45a966851e0fae2d853ff99f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py @@ -0,0 +1,132 @@ +# mypy: allow-untyped-defs +import random +from collections.abc import Iterator +from typing import TypeVar + +import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +__all__ = ["ShufflerIterDataPipe"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +# @functional_datapipe('shuffle') +class ShufflerIterDataPipe(IterDataPipe[_T_co]): + r""" + Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``). + + When it is used with :class:`~torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed + for each worker process. + + Args: + datapipe: MapDataPipe being shuffled + indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> shuffle_dp = dp.shuffle().set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + >>> list(shuffle_dp) + [6, 1, 9, 5, 2, 4, 7, 3, 8, 0] + >>> # Reset seed for Shuffler + >>> shuffle_dp = shuffle_dp.set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + + Note: + Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an + ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to + the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order + of data during data-processing. + """ + + datapipe: MapDataPipe[_T_co] + _enabled: bool + _seed: int | None + _rng: random.Random + + def __init__( + self, + datapipe: MapDataPipe[_T_co], + *, + indices: list | None = None, + ) -> None: + super().__init__() + self.datapipe = datapipe + # pyrefly: ignore [bad-argument-type] + self.indices = list(range(len(datapipe))) if indices is None else indices + self._enabled = True + self._seed = None + self._rng = random.Random() + self._shuffled_indices: list = self.indices + + def set_shuffle(self, shuffle=True): + self._enabled = shuffle + return self + + def set_seed(self, seed: int): + self._seed = seed + return self + + def __iter__(self) -> Iterator[_T_co]: + if not self._enabled: + for idx in self.indices: + yield self.datapipe[idx] + else: + while self._shuffled_indices: + idx = self._shuffled_indices.pop() + yield self.datapipe[idx] + + def reset(self) -> None: + if self._enabled and self._seed is None: + self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._rng.seed(self._seed) + self._seed = None + self._shuffled_indices = self._rng.sample(self.indices, len(self.indices)) + + def __len__(self) -> int: + # pyrefly: ignore [bad-argument-type] + return len(self.datapipe) + + def __getstate__(self): + state = ( + self.datapipe, + self.indices, + self._enabled, + self._seed, + self._rng.getstate(), + self._shuffled_indices, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.indices, + self._enabled, + self._seed, + rng_state, + self._shuffled_indices, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._rng = random.Random() + self._rng.setstate(rng_state) + + +MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py new file mode 100644 index 0000000000000000000000000000000000000000..c11d0bcd17d99b2fbceda986e229fb2257e1ec67 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py @@ -0,0 +1,109 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import MapDataPipe + + +__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"] + +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("concat") +class ConcaterMapDataPipe(MapDataPipe): + r""" + Concatenate multiple Map DataPipes (functional name: ``concat``). + + The new index of is the cumulative sum of source DataPipes. + For example, if there are 2 source DataPipes both with length 5, + index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to + elements of the first DataPipe, and 5 to 9 would refer to elements + of the second DataPipe. + + Args: + datapipes: Map DataPipes being concatenated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(3)) + >>> concat_dp = dp1.concat(dp2) + >>> list(concat_dp) + [0, 1, 2, 0, 1, 2] + """ + + datapipes: tuple[MapDataPipe] + + def __init__(self, *datapipes: MapDataPipe) -> None: + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, MapDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `MapDataPipe`") + if not all(isinstance(dp, Sized) for dp in datapipes): + raise TypeError("Expected all inputs to be `Sized`") + self.datapipes = datapipes # type: ignore[assignment] + + def __getitem__(self, index) -> _T_co: # type: ignore[type-var] + offset = 0 + for dp in self.datapipes: + # pyrefly: ignore [bad-argument-type] + if index - offset < len(dp): + return dp[index - offset] + else: + # pyrefly: ignore [bad-argument-type] + offset += len(dp) + raise IndexError(f"Index {index} is out of range.") + + def __len__(self) -> int: + # pyrefly: ignore [bad-argument-type] + return sum(len(dp) for dp in self.datapipes) + + +@functional_datapipe("zip") +class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]): + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Map DataPipes being aggregated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(10, 13)) + >>> zip_dp = dp1.zip(dp2) + >>> list(zip_dp) + [(0, 10), (1, 11), (2, 12)] + """ + + datapipes: tuple[MapDataPipe[_T_co], ...] + + def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None: + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, MapDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `MapDataPipe`") + if not all(isinstance(dp, Sized) for dp in datapipes): + raise TypeError("Expected all inputs to be `Sized`") + self.datapipes = datapipes + + def __getitem__(self, index) -> tuple[_T_co, ...]: + res = [] + for dp in self.datapipes: + try: + res.append(dp[index]) + except IndexError as e: + raise IndexError( + f"Index {index} is out of range for one of the input MapDataPipes {dp}." + ) from e + return tuple(res) + + def __len__(self) -> int: + # pyrefly: ignore [bad-argument-type] + return min(len(dp) for dp in self.datapipes) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..5929cab2427913d1ed3cae7494ef757513c73a40 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py @@ -0,0 +1,75 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe + + +__all__ = ["BatcherMapDataPipe"] + + +_T = TypeVar("_T") + + +@functional_datapipe("batch") +class BatcherMapDataPipe(MapDataPipe[DataChunk]): + r""" + Create mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, + or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> batch_dp = dp.batch(batch_size=2) + >>> list(batch_dp) + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] + """ + + datapipe: MapDataPipe + batch_size: int + drop_last: bool + + def __init__( + self, + datapipe: MapDataPipe[_T], + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> None: + if batch_size <= 0: + raise AssertionError("Batch size is required to be larger than 0!") + super().__init__() + self.datapipe = datapipe + self.batch_size = batch_size + self.drop_last = drop_last + self.wrapper_class = wrapper_class + + def __getitem__(self, index) -> DataChunk: + batch: list = [] + indices = range(index * self.batch_size, (index + 1) * self.batch_size) + try: + batch.extend(self.datapipe[i] for i in indices) + return self.wrapper_class(batch) + except IndexError as e: + if not self.drop_last and len(batch) > 0: + return self.wrapper_class(batch) + else: + raise IndexError(f"Index {index} is out of bound.") from e + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + if self.drop_last: + return len(self.datapipe) // self.batch_size + else: + return (len(self.datapipe) + self.batch_size - 1) // self.batch_size + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b9075f1dbbc66a84dfd14d0778cc96ca604da0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py @@ -0,0 +1,61 @@ +import copy +import warnings +from collections.abc import Mapping, Sequence +from typing import Any, TypeVar + +from torch.utils.data.datapipes.datapipe import MapDataPipe + + +_T = TypeVar("_T") + +__all__ = ["SequenceWrapperMapDataPipe"] + + +class SequenceWrapperMapDataPipe(MapDataPipe[_T]): + r""" + Wraps a sequence object into a MapDataPipe. + + Args: + sequence: Sequence object to be wrapped into an MapDataPipe + deepcopy: Option to deepcopy input sequence object + + .. note:: + If ``deepcopy`` is set to False explicitly, users should ensure + that data pipeline doesn't contain any in-place operations over + the iterable instance, in order to prevent data inconsistency + across iterations. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> list(dp) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400}) + >>> dp["a"] + 100 + """ + + sequence: Sequence[_T] | Mapping[Any, _T] + + def __init__( + self, sequence: Sequence[_T] | Mapping[Any, _T], deepcopy: bool = True + ) -> None: + if deepcopy: + try: + self.sequence = copy.deepcopy(sequence) + except TypeError: + warnings.warn( + "The input sequence can not be deepcopied, " + "please be aware of in-place modification would affect source data", + stacklevel=2, + ) + self.sequence = sequence + else: + self.sequence = sequence + + def __getitem__(self, index: int) -> _T: + return self.sequence[index] + + def __len__(self) -> int: + return len(self.sequence) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3329659302f4ef7e89b7ce6f0825c10559aa8778 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d8149b38cfd7452f002fa0071abe9727b6bdd07 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b846d2bfe867d84fcb3bd1addc47f036241194d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a49342cf1dc444b77f339482430719b98bd607b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcc617b3b722b4b9acfe0006198017858eb60b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py @@ -0,0 +1,415 @@ +# mypy: allow-untyped-defs +import fnmatch +import functools +import inspect +import os +import warnings +from collections.abc import Callable, Iterable +from io import IOBase +from typing import Any, NoReturn + +from torch.utils._import_utils import dill_available + + +__all__ = [ + "validate_input_col", + "StreamWrapper", + "get_file_binaries_from_pathnames", + "get_file_pathnames_from_root", + "match_masks", + "validate_pathname_binary_tuple", +] + + +# BC for torchdata +DILL_AVAILABLE = dill_available() + + +def validate_input_col(fn: Callable, input_col: int | tuple | list | None) -> None: + """ + Check that function used in a callable datapipe works with the input column. + + This simply ensures that the number of positional arguments matches the size + of the input column. The function must not contain any non-default + keyword-only arguments. + + Examples: + >>> # xdoctest: +SKIP("Failing on some CI machines") + >>> def f(a, b, *, c=1): + >>> return a + b + c + >>> def f_def(a, b=1, *, c=1): + >>> return a + b + c + >>> assert validate_input_col(f, [1, 2]) + >>> assert validate_input_col(f_def, 1) + >>> assert validate_input_col(f_def, [1, 2]) + + Notes: + If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments, + for example, f(a, *args), the validator will accept any size of input column + greater than or equal to the number of positional arguments. + (in this case, 1). + + Args: + fn: The function to check. + input_col: The input column to check. + + Raises: + ValueError: If the function is not compatible with the input column. + """ + try: + sig = inspect.signature(fn) + except ( + ValueError + ): # Signature cannot be inspected, likely it is a built-in fn or written in C + return + if isinstance(input_col, (list, tuple)): + input_col_size = len(input_col) + else: + input_col_size = 1 + + pos = [] + var_positional = False + non_default_kw_only = [] + + for p in sig.parameters.values(): + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + pos.append(p) + elif p.kind is inspect.Parameter.VAR_POSITIONAL: + var_positional = True + elif p.kind is inspect.Parameter.KEYWORD_ONLY: + if p.default is p.empty: + non_default_kw_only.append(p) + else: + continue + + if isinstance(fn, functools.partial): + fn_name = getattr(fn.func, "__name__", repr(fn.func)) + else: + fn_name = getattr(fn, "__name__", repr(fn)) + + if len(non_default_kw_only) > 0: + raise ValueError( + f"The function {fn_name} takes {len(non_default_kw_only)} " + f"non-default keyword-only parameters, which is not allowed." + ) + + if len(sig.parameters) < input_col_size: + if not var_positional: + raise ValueError( + f"The function {fn_name} takes {len(sig.parameters)} " + f"parameters, but {input_col_size} are required." + ) + else: + if len(pos) > input_col_size: + if any(p.default is p.empty for p in pos[input_col_size:]): + raise ValueError( + f"The function {fn_name} takes {len(pos)} " + f"positional parameters, but {input_col_size} are required." + ) + elif len(pos) < input_col_size: + if not var_positional: + raise ValueError( + f"The function {fn_name} takes {len(pos)} " + f"positional parameters, but {input_col_size} are required." + ) + + +def _is_local_fn(fn): + # Functions or Methods + if hasattr(fn, "__code__"): + return fn.__code__.co_flags & inspect.CO_NESTED + # Callable Objects + else: + if hasattr(fn, "__qualname__"): + return "" in fn.__qualname__ + fn_type = type(fn) + if hasattr(fn_type, "__qualname__"): + return "" in fn_type.__qualname__ + return False + + +def _check_unpickable_fn(fn: Callable) -> None: + """ + Check function is pickable or not. + + If it is a lambda or local function, a UserWarning will be raised. If it's not a callable function, a TypeError will be raised. + """ + if not callable(fn): + raise TypeError(f"A callable function is expected, but {type(fn)} is provided.") + + # Extract function from partial object + # Nested partial function is automatically expanded as a single partial object + if isinstance(fn, functools.partial): + fn = fn.func + + # Local function + if _is_local_fn(fn) and not dill_available(): + warnings.warn( + "Local function is not supported by pickle, please use " + "regular python function or functools.partial instead.", + stacklevel=2, + ) + return + + # Lambda function + if hasattr(fn, "__name__") and fn.__name__ == "" and not dill_available(): + warnings.warn( + "Lambda function is not supported by pickle, please use " + "regular python function or functools.partial instead.", + stacklevel=2, + ) + return + + +def match_masks(name: str, masks: str | list[str]) -> bool: + # empty mask matches any input name + if not masks: + return True + + if isinstance(masks, str): + return fnmatch.fnmatch(name, masks) + + for mask in masks: + if fnmatch.fnmatch(name, mask): + return True + return False + + +def get_file_pathnames_from_root( + root: str, + masks: str | list[str], + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, +) -> Iterable[str]: + # print out an error message and raise the error out + def onerror(err: OSError) -> NoReturn: + warnings.warn(err.filename + " : " + err.strerror, stacklevel=2) + raise err + + if os.path.isfile(root): + path = root + if abspath: + path = os.path.abspath(path) + fname = os.path.basename(path) + if match_masks(fname, masks): + yield path + else: + # pyrefly: ignore [bad-assignment] + for path, dirs, files in os.walk(root, onerror=onerror): + if abspath: + path = os.path.abspath(path) + if not non_deterministic: + files.sort() + for f in files: + if match_masks(f, masks): + yield os.path.join(path, f) + if not recursive: + break + if not non_deterministic: + # Note that this is in-place modifying the internal list from `os.walk` + # This only works because `os.walk` doesn't shallow copy before turn + # https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/os.py#L407 + dirs.sort() + + +def get_file_binaries_from_pathnames( + pathnames: Iterable, mode: str, encoding: str | None = None +): + if not isinstance(pathnames, Iterable): + pathnames = [ + pathnames, + ] + + if mode in ("b", "t"): + mode = "r" + mode + + for pathname in pathnames: + if not isinstance(pathname, str): + raise TypeError( + f"Expected string type for pathname, but got {type(pathname)}" + ) + yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) # noqa:SIM115 + + +def validate_pathname_binary_tuple(data: tuple[str, IOBase]) -> None: + if not isinstance(data, tuple): + raise TypeError( + f"pathname binary data should be tuple type, but it is type {type(data)}" + ) + if len(data) != 2: + raise TypeError( + f"pathname binary stream tuple length should be 2, but got {len(data)}" + ) + if not isinstance(data[0], str): + raise TypeError( + f"pathname within the tuple should have string type pathname, but it is type {type(data[0])}" + ) + if not isinstance(data[1], IOBase) and not isinstance(data[1], StreamWrapper): + raise TypeError( + f"binary stream within the tuple should have IOBase or" + f"its subclasses as type, but it is type {type(data[1])}" + ) + + +# Deprecated function names and its corresponding DataPipe type and kwargs for the `_deprecation_warning` function +_iter_deprecated_functional_names: dict[str, dict] = {} +_map_deprecated_functional_names: dict[str, dict] = {} + + +def _deprecation_warning( + old_class_name: str, + *, + deprecation_version: str, + removal_version: str, + old_functional_name: str = "", + old_argument_name: str = "", + new_class_name: str = "", + new_functional_name: str = "", + new_argument_name: str = "", + deprecate_functional_name_only: bool = False, +) -> None: + if new_functional_name and not old_functional_name: + raise ValueError( + "Old functional API needs to be specified for the deprecation warning." + ) + if new_argument_name and not old_argument_name: + raise ValueError( + "Old argument name needs to be specified for the deprecation warning." + ) + + if old_functional_name and old_argument_name: + raise ValueError( + "Deprecating warning for functional API and argument should be separated." + ) + + msg = f"`{old_class_name}()`" + if deprecate_functional_name_only and old_functional_name: + msg = f"{msg}'s functional API `.{old_functional_name}()` is" + elif old_functional_name: + msg = f"{msg} and its functional API `.{old_functional_name}()` are" + elif old_argument_name: + msg = f"The argument `{old_argument_name}` of {msg} is" + else: + msg = f"{msg} is" + msg = ( + f"{msg} deprecated since {deprecation_version} and will be removed in {removal_version}." + f"\nSee https://github.com/pytorch/data/issues/163 for details." + ) + + if new_class_name or new_functional_name: + msg = f"{msg}\nPlease use" + if new_class_name: + msg = f"{msg} `{new_class_name}()`" + if new_class_name and new_functional_name: + msg = f"{msg} or" + if new_functional_name: + msg = f"{msg} `.{new_functional_name}()`" + msg = f"{msg} instead." + + if new_argument_name: + msg = f"{msg}\nPlease use `{old_class_name}({new_argument_name}=)` instead." + + warnings.warn(msg, FutureWarning, stacklevel=2) + + +class StreamWrapper: + """ + StreamWrapper is introduced to wrap file handler generated by DataPipe operation like `FileOpener`. + + StreamWrapper would guarantee the wrapped file handler is closed when it's out of scope. + """ + + session_streams: dict[Any, int] = {} + debug_unclosed_streams: bool = False + + def __init__(self, file_obj, parent_stream=None, name=None) -> None: + self.file_obj = file_obj + self.child_counter = 0 + self.parent_stream = parent_stream + self.close_on_last_child = False + self.name = name + self.closed = False + if parent_stream is not None: + if not isinstance(parent_stream, StreamWrapper): + raise RuntimeError( + f"Parent stream should be StreamWrapper, {type(parent_stream)} was given" + ) + parent_stream.child_counter += 1 + self.parent_stream = parent_stream + if StreamWrapper.debug_unclosed_streams: + StreamWrapper.session_streams[self] = 1 + + @classmethod + def close_streams(cls, v, depth=0) -> None: + """Traverse structure and attempts to close all found StreamWrappers on best effort basis.""" + if depth > 10: + return + if isinstance(v, StreamWrapper): + v.close() + else: + # Traverse only simple structures + if isinstance(v, dict): + for vv in v.values(): + cls.close_streams(vv, depth=depth + 1) + elif isinstance(v, (list, tuple)): + for vv in v: + cls.close_streams(vv, depth=depth + 1) + + def __getattr__(self, name): + file_obj = self.__dict__["file_obj"] + return getattr(file_obj, name) + + def close(self, *args, **kwargs) -> None: + if self.closed: + return + if StreamWrapper.debug_unclosed_streams: + del StreamWrapper.session_streams[self] + if hasattr(self, "parent_stream") and self.parent_stream is not None: + self.parent_stream.child_counter -= 1 + if ( + not self.parent_stream.child_counter + and self.parent_stream.close_on_last_child + ): + self.parent_stream.close() + try: + self.file_obj.close(*args, **kwargs) + except AttributeError: + pass + self.closed = True + + def autoclose(self) -> None: + """Automatically close stream when all child streams are closed or if there are none.""" + self.close_on_last_child = True + if self.child_counter == 0: + self.close() + + def __dir__(self): + attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys()) + attrs += dir(self.file_obj) + return list(set(attrs)) + + def __del__(self) -> None: + if not self.closed: + self.close() + + def __iter__(self): + yield from self.file_obj + + def __next__(self): + return next(self.file_obj) + + def __repr__(self) -> str: + if self.name is None: + return f"StreamWrapper<{self.file_obj!r}>" + else: + return f"StreamWrapper<{self.name},{self.file_obj!r}>" + + def __getstate__(self): + return self.file_obj + + def __setstate__(self, obj): + self.file_obj = obj diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3b907ffebdd22d663cef50b0cc55166c58ec6192 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-defs +# This file takes partial of the implementation from NVIDIA's webdataset at here: +# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py + +import io +import json +import os.path +import pickle +import tempfile + +import torch +from torch.utils.data.datapipes.utils.common import StreamWrapper + + +__all__ = [ + "Decoder", + "ImageHandler", + "MatHandler", + "audiohandler", + "basichandlers", + "extension_extract_fn", + "handle_extension", + "imagehandler", + "mathandler", + "videohandler", +] + + +################################################################ +# handle basic datatypes +################################################################ +def basichandlers(extension: str, data): + """Transforms raw data (byte stream) into python objects. + + Looks at the extension and loads the data into a python object supporting + the corresponding extension. + + Args: + extension (str): The file extension + data (byte stream): Data to load into a python object. + + Returns: + object: The data loaded into a corresponding python object + supporting the extension. + + Example: + >>> import pickle + >>> data = pickle.dumps("some data") + >>> new_data = basichandlers("pickle", data) + >>> new_data + some data + + The transformation of data for extensions are: + - txt, text, transcript: utf-8 decoded data of str format + - cls, cls2, class, count, index, inx, id: int + - json, jsn: json loaded data + - pickle, pyd: pickle loaded data + - pt: torch loaded data + """ + + if extension in "txt text transcript": + return data.decode("utf-8") + + if extension in ["cls", "cls2", "class", "count", "index", "inx", "id"]: + try: + return int(data) + except ValueError: + return None + + if extension in "json jsn": + return json.loads(data) + + if extension in ["pyd", "pickle"]: + return pickle.loads(data) + + if extension in ["pt"]: + stream = io.BytesIO(data) + return torch.load(stream) + + # if extension in "ten tb".split(): + # from . import tenbin + # return tenbin.decode_buffer(data) + + # if extension in "mp msgpack msg".split(): + # import msgpack + # return msgpack.unpackb(data) + + return None + + +################################################################ +# handle images +################################################################ +imagespecs = { + "l8": ("numpy", "uint8", "l"), + "rgb8": ("numpy", "uint8", "rgb"), + "rgba8": ("numpy", "uint8", "rgba"), + "l": ("numpy", "float", "l"), + "rgb": ("numpy", "float", "rgb"), + "rgba": ("numpy", "float", "rgba"), + "torchl8": ("torch", "uint8", "l"), + "torchrgb8": ("torch", "uint8", "rgb"), + "torchrgba8": ("torch", "uint8", "rgba"), + "torchl": ("torch", "float", "l"), + "torchrgb": ("torch", "float", "rgb"), + "torch": ("torch", "float", "rgb"), + "torchrgba": ("torch", "float", "rgba"), + "pill": ("pil", None, "l"), + "pil": ("pil", None, "rgb"), + "pilrgb": ("pil", None, "rgb"), + "pilrgba": ("pil", None, "rgba"), +} + + +def handle_extension(extensions, f): + """ + Return a decoder handler function for the list of extensions. + + Extensions can be a space separated list of extensions. + Extensions can contain dots, in which case the corresponding number + of extension components must be present in the key given to f. + Comparisons are case insensitive. + Examples: + handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg + handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg + """ + extensions = extensions.lower().split() + + def g(key, data): + extension = key.lower().split(".") + + for target in extensions: + target = target.split(".") + if len(target) > len(extension): + continue + + if extension[-len(target) :] == target: + return f(data) + return None + + return g + + +class ImageHandler: + """ + Decode image data using the given `imagespec`. + + The `imagespec` specifies whether the image is decoded + to numpy/torch/pi, decoded to uint8/float, and decoded + to l/rgb/rgba: + + - l8: numpy uint8 l + - rgb8: numpy uint8 rgb + - rgba8: numpy uint8 rgba + - l: numpy float l + - rgb: numpy float rgb + - rgba: numpy float rgba + - torchl8: torch uint8 l + - torchrgb8: torch uint8 rgb + - torchrgba8: torch uint8 rgba + - torchl: torch float l + - torchrgb: torch float rgb + - torch: torch float rgb + - torchrgba: torch float rgba + - pill: pil None l + - pil: pil None rgb + - pilrgb: pil None rgb + - pilrgba: pil None rgba + """ + + def __init__(self, imagespec) -> None: + if imagespec not in list(imagespecs.keys()): + raise AssertionError(f"unknown image specification: {imagespec}") + self.imagespec = imagespec.lower() + + def __call__(self, extension, data): + if extension.lower() not in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]: + return None + + try: + import numpy as np + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Package `numpy` is required to be installed for default image decoder." + "Please use `pip install numpy` to install the package" + ) from e + + try: + import PIL.Image + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Package `PIL` is required to be installed for default image decoder." + "Please use `pip install Pillow` to install the package" + ) from e + + imagespec = self.imagespec + atype, etype, mode = imagespecs[imagespec] + + with io.BytesIO(data) as stream: + img = PIL.Image.open(stream) + img.load() + img = img.convert(mode.upper()) + if atype == "pil": + return img + elif atype == "numpy": + result = np.asarray(img) + if result.dtype != np.uint8: + raise AssertionError( + f"numpy image array should be type uint8, but got {result.dtype}" + ) + if etype == "uint8": + return result + else: + return result.astype("f") / 255.0 + elif atype == "torch": + result = np.asarray(img) + if result.dtype != np.uint8: + raise AssertionError( + f"numpy image array should be type uint8, but got {result.dtype}" + ) + + if etype == "uint8": + result = np.array(result.transpose(2, 0, 1)) + return torch.tensor(result) + else: + result = np.array(result.transpose(2, 0, 1)) + return torch.tensor(result) / 255.0 + return None + + +def imagehandler(imagespec): + return ImageHandler(imagespec) + + +################################################################ +# torch video +################################################################ +def videohandler(extension, data): + if extension not in [ + "mp4", + "ogv", + "mjpeg", + "avi", + "mov", + "h264", + "mpg", + "webm", + "wmv", + ]: + return None + + try: + import torchvision.io + except ImportError as e: + raise ModuleNotFoundError( + "Package `torchvision` is required to be installed for default video file loader." + "Please use `pip install torchvision`" + "to install the package" + ) from e + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchvision.io.read_video(fname) + + +################################################################ +# torchaudio +################################################################ +def audiohandler(extension, data): + if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: + return None + + try: + import torchaudio # type: ignore[import] + except ImportError as e: + raise ModuleNotFoundError( + "Package `torchaudio` is required to be installed for default audio file loader." + "Please use `pip install torchaudio`" + "to install the package" + ) from e + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchaudio.load(fname) + + +################################################################ +# mat +################################################################ +class MatHandler: + def __init__(self, **loadmat_kwargs) -> None: + try: + import scipy.io as sio + except ImportError as e: + raise ModuleNotFoundError( + "Package `scipy` is required to be installed for mat file." + "Please use `pip install scipy`" + "to install the package" + ) from e + self.sio = sio + self.loadmat_kwargs = loadmat_kwargs + + def __call__(self, extension, data): + if extension != "mat": + return None + with io.BytesIO(data) as stream: + return self.sio.loadmat(stream, **self.loadmat_kwargs) + + +def mathandler(**loadmat_kwargs): + return MatHandler(**loadmat_kwargs) + + +################################################################ +# a sample decoder +################################################################ +# Extract extension from pathname +def extension_extract_fn(pathname): + ext = os.path.splitext(pathname)[1] + # Remove dot + if ext: + ext = ext[1:] + return ext + + +class Decoder: + """ + Decode key/data sets using a list of handlers. + + For each key/data item, this iterates through the list of + handlers until some handler returns something other than None. + """ + + def __init__(self, *handler, key_fn=extension_extract_fn) -> None: + self.handlers = list(handler) if handler else [] + self.key_fn = key_fn + + # Insert new handler from the beginning of handlers list to make sure the new + # handler having the highest priority + def add_handler(self, *handler) -> None: + if not handler: + return + self.handlers = list(handler) + self.handlers + + @staticmethod + def _is_stream_handle(data): + obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data + return isinstance(obj_to_check, (io.BufferedIOBase, io.RawIOBase)) + + def decode1(self, key, data): + if not data: + return data + + # if data is a stream handle, we need to read all the content before decoding + if Decoder._is_stream_handle(data): + ds = data + # The behavior of .read can differ between streams (e.g. HTTPResponse), hence this is used instead + data = b"".join(data) + ds.close() + + for f in self.handlers: + result = f(key, data) + if result is not None: + return result + return data + + def decode(self, data): + result = {} + # single data tuple(pathname, data stream) + if isinstance(data, tuple): + data = [data] + + if data is not None: + for k, v in data: + # TODO: xinyu, figure out why Nvidia do this? + if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + result[k] = self.decode1(self.key_fn(k), v) + return result + + def __call__(self, data): + return self.decode(data) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py new file mode 100644 index 0000000000000000000000000000000000000000..42aec1aa308a9b21b251de595cddfbe171930bb6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.graph_settings import apply_random_seed + + +# TODO: Caveats +# 1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG +# 2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently +# lack the option to `set_seed`. +def _simple_graph_snapshot_restoration( + datapipe: IterDataPipe, n_iterations: int, rng=None +) -> None: + r""" + Fast-forward the given DataPipe and its parents by ``n_iterations``, re-doing computations to restore a snapshot. + + For instance, applying this function to the final DataPipe of a graph will restore the snapshot + (via fast-forward) every DataPipe within the graph. + + After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input + to this function to forward the DataPipe. + + A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration + attempts. + + Note: + This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding + methods (custom ones if necessary) are recommended. + + Args: + datapipe: IterDataPipe to be fast-forwarded + n_iterations: number of iterations to fast-forward + rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator + should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``. + """ + if datapipe._snapshot_state == _SnapshotState.Restored: + raise RuntimeError( + "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph " + "if your graph has not been restored." + ) + + # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to + # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`, + # the first reset will not actually reset. + datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`. + # pyrefly: ignore [bad-argument-type] + apply_random_seed(datapipe, rng) + + remainder = n_iterations + it = iter(datapipe) # This always reset the DataPipe if it hasn't already. + while remainder > 0: + try: + next(it) + remainder -= 1 + except StopIteration as e: + raise RuntimeError( + f"Fast-forward {datapipe} by {n_iterations} iterations " + "exceeds the number of samples available." + ) from e + datapipe._fast_forward_iterator = it + # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere. + + # This will prevent the DataPipe from resetting in the `iter()` call + # If another DataPipe is consuming it, it won't have to start over again + datapipe._snapshot_state = _SnapshotState.Restored diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..19ec449f040dd9ff87bbd85ca9ea4a003d6f17d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/dataset.py @@ -0,0 +1,481 @@ +# mypy: allow-untyped-defs +import bisect +import itertools +import math +import warnings +from collections.abc import Sequence + +# UP006 wants 'Iterable' to be imported from collections.abc but it needs to +# stay from typing for now due to BC concerns. In particular several internal +# targets fail to typecheck with: +# TypeError: Cannot create a consistent method resolution order (MRO) for +# bases Iterable, Generic +from typing import cast, Generic, Iterable, TypeVar # noqa: UP035 +from typing_extensions import deprecated + +# No 'default_generator' in torch/__init__.pyi +from torch import default_generator, Generator, randperm, Tensor + + +__all__ = [ + "Dataset", + "IterableDataset", + "TensorDataset", + "StackDataset", + "ConcatDataset", + "ChainDataset", + "Subset", + "random_split", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_dict = dict[str, _T_co] +_T_tuple = tuple[_T_co, ...] +_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict) + + +class Dataset(Generic[_T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. Subclasses could also + optionally implement :meth:`__getitems__`, for speedup batched samples + loading. This method accepts list of indices of samples of batch and returns + list of samples. + + .. note:: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> _T_co: + raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") + + # def __getitems__(self, indices: List) -> List[_T_co]: + # Not implemented to prevent false-positives in fetcher check in + # torch.utils.data._utils.fetch._MapDatasetFetcher + + def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": + return ConcatDataset([self, other]) + + # No `def __len__(self)` default? + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # in pytorch/torch/utils/data/sampler.py + + +class IterableDataset(Dataset[_T_co], Iterable[_T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. + + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> # xdoctest: +SKIP("Fails on MacOS12") + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... worker_info = torch.utils.data.get_worker_info() + ... if worker_info is None: # single-process data loading, return the full iterator + ... iter_start = self.start + ... iter_end = self.end + ... else: # in a worker process + ... # split workload + ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... iter_start = self.start + worker_id * per_worker + ... iter_end = min(iter_start + per_worker, self.end) + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [tensor([3]), tensor([4]), tensor([5]), tensor([6])] + + >>> # xdoctest: +REQUIRES(POSIX) + >>> # Multi-process loading with two worker processes + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + >>> # With even more workers + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + >>> + >>> # Directly doing multi-process loading yields duplicate data + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [3, 3, 4, 4, 5, 5, 6, 6] + + >>> # Define a `worker_init_fn` that configures each dataset copy differently + >>> def worker_init_fn(worker_id): + ... worker_info = torch.utils.data.get_worker_info() + ... dataset = worker_info.dataset # the dataset copy in this worker process + ... overall_start = dataset.start + ... overall_end = dataset.end + ... # configure the dataset to only process the split workload + ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... dataset.start = overall_start + worker_id * per_worker + ... dataset.end = min(dataset.start + per_worker, overall_end) + ... + + >>> # Mult-process loading with the custom `worker_init_fn` + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) + [3, 4, 5, 6] + """ + + def __add__(self, other: Dataset[_T_co]): + return ChainDataset([self, other]) + + # No `def __len__(self)` default? Subclasses raise `TypeError` when needed. + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + + +class TensorDataset(Dataset[tuple[Tensor, ...]]): + r"""Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Args: + *tensors (Tensor): tensors that have the same size of the first dimension. + """ + + tensors: tuple[Tensor, ...] + + def __init__(self, *tensors: Tensor) -> None: + if all(tensors[0].size(0) != tensor.size(0) for tensor in tensors): + raise AssertionError("Size mismatch between tensors") + self.tensors = tensors + + def __getitem__(self, index): + return tuple(tensor[index] for tensor in self.tensors) + + def __len__(self) -> int: + return self.tensors[0].size(0) + + +class StackDataset(Dataset[_T_stack]): + r"""Dataset as a stacking of multiple datasets. + + This class is useful to assemble different parts of complex input data, given as datasets. + + Example: + >>> # xdoctest: +SKIP + >>> images = ImageDataset() + >>> texts = TextDataset() + >>> tuple_stack = StackDataset(images, texts) + >>> tuple_stack[0] == (images[0], texts[0]) + >>> dict_stack = StackDataset(image=images, text=texts) + >>> dict_stack[0] == {"image": images[0], "text": texts[0]} + + Args: + *args (Dataset): Datasets for stacking returned as tuple. + **kwargs (Dataset): Datasets for stacking returned as dict. + """ + + datasets: tuple | dict + + def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: + if args: + if kwargs: + raise ValueError( + "Supported either ``tuple``- (via ``args``) or" + "``dict``- (via ``kwargs``) like input/output, but both types are given." + ) + self._length = len(args[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = args + elif kwargs: + tmp = list(kwargs.values()) + self._length = len(tmp[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = kwargs + else: + raise ValueError("At least one dataset should be passed") + + def __getitem__(self, index): + if isinstance(self.datasets, dict): + return {k: dataset[index] for k, dataset in self.datasets.items()} + return tuple(dataset[index] for dataset in self.datasets) + + def __getitems__(self, indices: list): + # add batched sampling support when parent datasets supports it. + if isinstance(self.datasets, dict): + dict_batch: list[_T_dict] = [{} for _ in indices] + for k, dataset in self.datasets.items(): + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, d_sample in zip(items, dict_batch, strict=True): + d_sample[k] = data + else: + for idx, d_sample in zip(indices, dict_batch, strict=True): + d_sample[k] = dataset[idx] + return dict_batch + + # tuple data + list_batch: list[list] = [[] for _ in indices] + for dataset in self.datasets: + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, t_sample in zip(items, list_batch, strict=True): + t_sample.append(data) + else: + for idx, t_sample in zip(indices, list_batch, strict=True): + t_sample.append(dataset[idx]) + tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] + return tuple_batch + + def __len__(self) -> int: + return self._length + + +class ConcatDataset(Dataset[_T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + + datasets: list[Dataset[_T_co]] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = list(datasets) + if len(self.datasets) == 0: + raise AssertionError("datasets should not be an empty iterable") + for d in self.datasets: + if isinstance(d, IterableDataset): + raise AssertionError("ConcatDataset does not support IterableDataset") + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self) -> int: + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) + def cummulative_sizes(self): + return self.cumulative_sizes + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + if not isinstance(d, IterableDataset): + raise AssertionError("ChainDataset only supports IterableDataset") + yield from d + + def __len__(self) -> int: + total = 0 + for d in self.datasets: + if not isinstance(d, IterableDataset): + raise AssertionError("ChainDataset only supports IterableDataset") + total += len(d) # type: ignore[arg-type] + return total + + +class Subset(Dataset[_T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + dataset: Dataset[_T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.dataset[[self.indices[i] for i in idx]] + return self.dataset[self.indices[idx]] + + def __getitems__(self, indices: list[int]) -> list[_T_co]: + # add batched sampling support when parent dataset supports it. + # see torch.utils.data._utils.fetch._MapDatasetFetcher + if callable(getattr(self.dataset, "__getitems__", None)): + return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] + else: + return [self.dataset[self.indices[idx]] for idx in indices] + + def __len__(self) -> int: + return len(self.indices) + + +def random_split( + dataset: Dataset[_T], + lengths: Sequence[int | float], + generator: Generator | None = default_generator, +) -> list[Subset[_T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + + If a list of fractions that sum up to 1 is given, + the lengths will be computed automatically as + floor(frac * len(dataset)) for each fraction provided. + + After computing the lengths, if there are any remainders, 1 count will be + distributed in round-robin fashion to the lengths + until there are no remainders left. + + Optionally fix the generator for reproducible results, e.g.: + + Example: + >>> # xdoctest: +SKIP + >>> generator1 = torch.Generator().manual_seed(42) + >>> generator2 = torch.Generator().manual_seed(42) + >>> random_split(range(10), [3, 7], generator=generator1) + >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths or fractions of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths: list[int] = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = math.floor(len(dataset) * frac) # type: ignore[arg-type] + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset.", + stacklevel=2, + ) + + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] + lengths = cast(Sequence[int], lengths) + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(itertools.accumulate(lengths), lengths, strict=True) + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..5179d7698ffee0f2acda62a2b2073df176aae794 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/distributed.py @@ -0,0 +1,157 @@ +import math +from collections.abc import Iterator +from typing import TypeVar + +import torch +import torch.distributed as dist +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import Sampler + + +__all__ = ["DistributedSampler"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class DistributedSampler(Sampler[_T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: int | None = None, + rank: int | None = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + if len(indices) != self.total_size: + raise AssertionError( + f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})" + ) + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + if len(indices) != self.num_samples: + raise AssertionError( + f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})" + ) + + # pyrefly: ignore [bad-return] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..f735aa35fec110ccf5d36febf6519227ec166b28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +import io +import pickle +import warnings +from collections.abc import Collection + +from torch.utils._import_utils import dill_available +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +__all__ = ["traverse", "traverse_dps"] + +DataPipe = IterDataPipe | MapDataPipe +DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]] + + +def _stub_unpickler() -> str: + return "STUB" + + +# TODO(VitalyFedyunin): Make sure it works without dill module installed +def _list_connected_datapipes( + scan_obj: DataPipe, only_datapipe: bool, cache: set[int] +) -> list[DataPipe]: + f = io.BytesIO() + p = pickle.Pickler( + f + ) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is + if dill_available(): + from dill import Pickler as dill_Pickler + + d = dill_Pickler(f) + else: + d = None + + captured_connections = [] + + def getstate_hook(ori_state): + state = None + if isinstance(ori_state, dict): + state = {} + for k, v in ori_state.items(): + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state[k] = v + elif isinstance(ori_state, (tuple, list)): + state = [] # type: ignore[assignment] + for v in ori_state: + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state.append(v) # type: ignore[attr-defined] + elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)): + state = ori_state # type: ignore[assignment] + return state + + def reduce_hook(obj): + if obj == scan_obj or id(obj) in cache: + raise NotImplementedError + else: + captured_connections.append(obj) + # Adding id to remove duplicate DataPipe serialized at the same level + cache.add(id(obj)) + return _stub_unpickler, () + + datapipe_classes: tuple[type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment] + + try: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(reduce_hook) + if only_datapipe: + cls.set_getstate_hook(getstate_hook) + try: + p.dump(scan_obj) + except (pickle.PickleError, AttributeError, TypeError): + if dill_available(): + # pyrefly: ignore [missing-attribute] + d.dump(scan_obj) + else: + raise + finally: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(None) + if only_datapipe: + cls.set_getstate_hook(None) + if dill_available(): + from dill import extend as dill_extend + + dill_extend(False) # Undo change to dispatch table + return captured_connections + + +def traverse_dps(datapipe: DataPipe) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + This only looks into the attribute from each DataPipe that is either a + DataPipe and a Python collection object such as ``list``, ``tuple``, + ``set`` and ``dict``. + + Args: + datapipe: the end DataPipe of the graph + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe=True, cache=cache) + + +def traverse(datapipe: DataPipe, only_datapipe: bool | None = None) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + [Deprecated] + When ``only_dataPipe`` is specified as ``True``, it would only look into the + attribute from each DataPipe that is either a DataPipe and a Python collection object + such as ``list``, ``tuple``, ``set`` and ``dict``. + + Note: + This function is deprecated. Please use `traverse_dps` instead. + + Args: + datapipe: the end DataPipe of the graph + only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed. + This argument is deprecating and will be removed after the next release. + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + msg = ( + "`traverse` function and will be removed after 1.13. " + "Please use `traverse_dps` instead." + ) + if not only_datapipe: + msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`." + warnings.warn(msg, FutureWarning, stacklevel=2) + if only_datapipe is None: + only_datapipe = False + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe, cache) + + +# Add cache here to prevent infinite recursion on DataPipe +def _traverse_helper( + datapipe: DataPipe, only_datapipe: bool, cache: set[int] +) -> DataPipeGraph: + if not isinstance(datapipe, (IterDataPipe, MapDataPipe)): + raise RuntimeError( + f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found" + ) + + dp_id = id(datapipe) + if dp_id in cache: + return {} + cache.add(dp_id) + # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths + items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy()) + d: DataPipeGraph = {dp_id: (datapipe, {})} + for item in items: + # Using cache.copy() here is to prevent recursion on a single path rather than global graph + # Single DataPipe can present multiple times in different paths in graph + d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy())) + return d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph_settings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..03096398a6738b29c22aad044caaf16e4c45a7d0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/graph_settings.py @@ -0,0 +1,173 @@ +# mypy: allow-untyped-defs +import inspect +import warnings +from typing import Any +from typing_extensions import deprecated + +import torch +from torch.utils.data.datapipes.iter.sharding import ( + _ShardingIterDataPipe, + SHARDING_PRIORITIES, +) +from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps + + +__all__ = [ + "apply_random_seed", + "apply_sharding", + "apply_shuffle_seed", + "apply_shuffle_settings", + "get_all_graph_pipes", +] + + +def get_all_graph_pipes(graph: DataPipeGraph) -> list[DataPipe]: + return _get_all_graph_pipes_helper(graph, set()) + + +def _get_all_graph_pipes_helper( + graph: DataPipeGraph, id_cache: set[int] +) -> list[DataPipe]: + results: list[DataPipe] = [] + for dp_id, (datapipe, sub_graph) in graph.items(): + if dp_id in id_cache: + continue + id_cache.add(dp_id) + results.append(datapipe) + results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache)) + return results + + +def _is_sharding_datapipe(datapipe: DataPipe) -> bool: + return isinstance(datapipe, _ShardingIterDataPipe) or ( + hasattr(datapipe, "apply_sharding") + and inspect.ismethod(datapipe.apply_sharding) + ) + + +def apply_sharding( + datapipe: DataPipe, + num_of_instances: int, + instance_id: int, + sharding_group=SHARDING_PRIORITIES.DEFAULT, +) -> DataPipe: + r""" + Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``. + + RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch. + """ + graph = traverse_dps(datapipe) + + def _helper(graph, prev_applied=None) -> None: + for dp, sub_graph in graph.values(): + applied = None + if _is_sharding_datapipe(dp): + if prev_applied is not None: + raise RuntimeError( + "Sharding twice on a single pipeline is likely unintended and will cause data loss. " + f"Sharding already applied to {prev_applied} while trying to apply to {dp}" + ) + # For BC, only provide sharding_group if accepted + sig = inspect.signature(dp.apply_sharding) + if len(sig.parameters) < 3: + dp.apply_sharding(num_of_instances, instance_id) + else: + dp.apply_sharding( + num_of_instances, instance_id, sharding_group=sharding_group + ) + applied = dp + if applied is None: + applied = prev_applied + _helper(sub_graph, applied) + + _helper(graph) + + return datapipe + + +def _is_shuffle_datapipe(datapipe: DataPipe) -> bool: + return ( + hasattr(datapipe, "set_shuffle") + and hasattr(datapipe, "set_seed") + and inspect.ismethod(datapipe.set_shuffle) + and inspect.ismethod(datapipe.set_seed) + ) + + +def apply_shuffle_settings(datapipe: DataPipe, shuffle: bool | None = None) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find and set shuffle attribute. + + Apply the method to each `DataPipe` that has APIs of ``set_shuffle`` + and ``set_seed``. + + Args: + datapipe: DataPipe that needs to set shuffle attribute + shuffle: Shuffle option (default: ``None`` and no-op to the graph) + """ + if shuffle is None: + return datapipe + + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)] + if not shufflers and shuffle: + warnings.warn( + "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. " + "Be aware that the default buffer size might not be sufficient for your task.", + stacklevel=2, + ) + datapipe = datapipe.shuffle() + shufflers = [ + datapipe, + ] + + for shuffler in shufflers: + shuffler.set_shuffle(shuffle) + + return datapipe + + +@deprecated( + "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. " + "Please use `apply_random_seed` instead.", + category=FutureWarning, +) +def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: + return apply_random_seed(datapipe, rng) + + +def _is_random_datapipe(datapipe: DataPipe) -> bool: + return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed) + + +def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``. + + Then set the random seed based on the provided RNG to those ``DataPipe``. + + Args: + datapipe: DataPipe that needs to set randomness + rng: Random number generator to generate random seeds + """ + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once. + # And, `id` is used in case of unhashable DataPipe + cache = set() + random_datapipes = [] + for pipe in all_pipes: + if id(pipe) in cache: + continue + if _is_random_datapipe(pipe): + random_datapipes.append(pipe) + cache.add(id(pipe)) + + for pipe in random_datapipes: + random_seed = int( + torch.empty((), dtype=torch.int64).random_(generator=rng).item() + ) + pipe.set_seed(random_seed) + + return datapipe diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/sampler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..aa13bb8e0a3e146bd7bfbc766fdfcb822efa9313 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/data/sampler.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable, Iterator, Sequence, Sized +from typing import Generic, TypeVar + +import torch + + +# Note: For benchmarking changes to samplers, see: +# /benchmarks/data/samplers_bench.py +# This benchmark compares the performance of different sampler implementations +# and can be used to evaluate the impact of optimizations. + + +__all__ = [ + "BatchSampler", + "RandomSampler", + "Sampler", + "SequentialSampler", + "SubsetRandomSampler", + "WeightedRandomSampler", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class Sampler(Generic[_T_co]): + r"""Base class for all Samplers. + + Every Sampler subclass has to provide an :meth:`__iter__` method, providing a + way to iterate over indices or lists of indices (batches) of dataset elements, + and may provide a :meth:`__len__` method that returns the length of the returned iterators. + + Example: + >>> # xdoctest: +SKIP + >>> class AccedingSequenceLengthSampler(Sampler[int]): + >>> def __init__(self, data: List[str]) -> None: + >>> self.data = data + >>> + >>> def __len__(self) -> int: + >>> return len(self.data) + >>> + >>> def __iter__(self) -> Iterator[int]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> yield from torch.argsort(sizes).tolist() + >>> + >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): + >>> def __init__(self, data: List[str], batch_size: int) -> None: + >>> self.data = data + >>> self.batch_size = batch_size + >>> + >>> def __len__(self) -> int: + >>> return (len(self.data) + self.batch_size - 1) // self.batch_size + >>> + >>> def __iter__(self) -> Iterator[List[int]]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): + >>> yield batch.tolist() + + .. note:: The :meth:`__len__` method isn't strictly required by + :class:`~torch.utils.data.DataLoader`, but is expected in any + calculation involving the length of a :class:`~torch.utils.data.DataLoader`. + """ + + def __iter__(self) -> Iterator[_T_co]: + raise NotImplementedError + + # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # + # Many times we have an abstract class representing a collection/iterable of + # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally + # implementing a `__len__` method. In such cases, we must make sure to not + # provide a default implementation, because both straightforward default + # implementations have their issues: + # + # + `return NotImplemented`: + # Calling `len(subclass_instance)` raises: + # TypeError: 'NotImplementedType' object cannot be interpreted as an integer + # + # + `raise NotImplementedError`: + # This prevents triggering some fallback behavior. E.g., the built-in + # `list(X)` tries to call `len(X)` first, and executes a different code + # path if the method is not found or `NotImplemented` is returned, while + # raising a `NotImplementedError` will propagate and make the call fail + # where it could have used `__iter__` to complete the call. + # + # Thus, the only two sensible things to do are + # + # + **not** provide a default `__len__`. + # + # + raise a `TypeError` instead, which is what Python uses when users call + # a method that is not defined on an object. + # (@ssnl verifies that this works on at least Python 3.7.) + + +class SequentialSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. + + Args: + data_source (Sized): data source to sample from. Must implement __len__. + """ + + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + + def __iter__(self) -> Iterator[int]: + return iter(range(len(self.data_source))) + + def __len__(self) -> int: + return len(self.data_source) + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Sized): data source to sample from. Must implement __len__. + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. + """ + + data_source: Sized + replacement: bool + + def __init__( + self, + data_source: Sized, + replacement: bool = False, + num_samples: int | None = None, + generator=None, + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError( + f"replacement should be a boolean value, but got replacement={self.replacement}" + ) + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" + ) + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint( + high=n, size=(32,), dtype=torch.int64, generator=generator + ).tolist() + yield from torch.randint( + high=n, + size=(self.num_samples % 32,), + dtype=torch.int64, + generator=generator, + ).tolist() + else: + for _ in range(self.num_samples // n): + yield from torch.randperm(n, generator=generator).tolist() + yield from torch.randperm(n, generator=generator).tolist()[ + : self.num_samples % n + ] + + def __len__(self) -> int: + return self.num_samples + + +class SubsetRandomSampler(Sampler[int]): + r"""Samples elements randomly from a given list of indices, without replacement. + + Args: + indices (sequence): a sequence of indices + generator (Generator): Generator used in sampling. + """ + + indices: Sequence[int] + + def __init__(self, indices: Sequence[int], generator=None) -> None: + self.indices = indices + self.generator = generator + + def __iter__(self) -> Iterator[int]: + for i in torch.randperm(len(self.indices), generator=self.generator).tolist(): + yield self.indices[i] + + def __len__(self) -> int: + return len(self.indices) + + +class WeightedRandomSampler(Sampler[int]): + r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). + + Args: + weights (sequence) : a sequence of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + generator (Generator): Generator used in sampling. + + Example: + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> list( + ... WeightedRandomSampler( + ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True + ... ) + ... ) + [4, 4, 1, 4, 5] + >>> list( + ... WeightedRandomSampler( + ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False + ... ) + ... ) + [0, 1, 4, 3, 2] + """ + + weights: torch.Tensor + num_samples: int + replacement: bool + + def __init__( + self, + weights: Sequence[float], + num_samples: int, + replacement: bool = True, + generator=None, + ) -> None: + if ( + not isinstance(num_samples, int) + or isinstance(num_samples, bool) + or num_samples <= 0 + ): + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={num_samples}" + ) + if not isinstance(replacement, bool): + raise ValueError( + f"replacement should be a boolean value, but got replacement={replacement}" + ) + + weights_tensor = torch.as_tensor(weights, dtype=torch.double) + if len(weights_tensor.shape) != 1: + raise ValueError( + "weights should be a 1d sequence but given " + f"weights have shape {tuple(weights_tensor.shape)}" + ) + + self.weights = weights_tensor + self.num_samples = num_samples + self.replacement = replacement + self.generator = generator + + def __iter__(self) -> Iterator[int]: + rand_tensor = torch.multinomial( + self.weights, self.num_samples, self.replacement, generator=self.generator + ) + yield from iter(rand_tensor.tolist()) + + def __len__(self) -> int: + return self.num_samples + + +class BatchSampler(Sampler[list[int]]): + r"""Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list( + ... BatchSampler( + ... SequentialSampler(range(10)), batch_size=3, drop_last=False + ... ) + ... ) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list( + ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True) + ... ) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__( + self, + sampler: Sampler[int] | Iterable[int], + batch_size: int, + drop_last: bool, + ) -> None: + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + f"batch_size should be a positive integer value, but got batch_size={batch_size}" + ) + if not isinstance(drop_last, bool): + raise ValueError( + f"drop_last should be a boolean value, but got drop_last={drop_last}" + ) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self) -> Iterator[list[int]]: + sampler_iter = iter(self.sampler) + if self.drop_last: + # Create multiple references to the same iterator + args = [sampler_iter] * self.batch_size + for batch_droplast in zip(*args, strict=False): + yield [*batch_droplast] + else: + batch = [*itertools.islice(sampler_iter, self.batch_size)] + while batch: + yield batch + batch = [*itertools.islice(sampler_iter, self.batch_size)] + + def __len__(self) -> int: + # Can only be called if self.sampler has __len__ implemented + # We cannot enforce this condition, so we turn off typechecking for the + # implementation below. + # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore[arg-type] + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58f3ace6c03d093337c9fa417ccbe8bc267b6c69 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__init__.py @@ -0,0 +1 @@ +from .version import __version__ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20db67090acc230be6463a4ec87313be50bb6d24 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15b4ef9323dcd06a7fbe93772eed97071e10855 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7f409257bb8ca666a1eb8b403f5819d49680ca9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121f10e1229d9d2befcc38dd994ce9041f58147b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/constants.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a9053b261ad44d1ef8b8cbdf3a27da0306d92f36 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/constants.py @@ -0,0 +1,62 @@ +"""Constants for annotations in the mapping. + +The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py. +They are based on +https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h +and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported +mapping. +""" + +CONV_VERSION = 0, +CONV_INIT = 1 +CONV_DEVICE = 2 +CONV_MEM = 3 +CONV_KERN = 4 +CONV_COORD_FUNC = 5 +CONV_MATH_FUNC = 6 +CONV_DEVICE_FUNC = 7 +CONV_SPECIAL_FUNC = 8 +CONV_STREAM = 9 +CONV_EVENT = 10 +CONV_OCCUPANCY = 11 +CONV_CONTEXT = 12 +CONV_PEER = 13 +CONV_MODULE = 14 +CONV_CACHE = 15 +CONV_EXEC = 16 +CONV_ERROR = 17 +CONV_DEF = 18 +CONV_TEX = 19 +CONV_GL = 20 +CONV_GRAPHICS = 21 +CONV_SURFACE = 22 +CONV_JIT = 23 +CONV_D3D9 = 24 +CONV_D3D10 = 25 +CONV_D3D11 = 26 +CONV_VDPAU = 27 +CONV_EGL = 28 +CONV_THREAD = 29 +CONV_OTHER = 30 +CONV_INCLUDE = 31 +CONV_INCLUDE_CUDA_MAIN_H = 32 +CONV_TYPE = 33 +CONV_LITERAL = 34 +CONV_NUMERIC_LITERAL = 35 +CONV_LAST = 36 + +API_DRIVER = 37 +API_RUNTIME = 38 +API_BLAS = 39 +API_SPECIAL = 40 +API_RAND = 41 +API_LAST = 42 +API_FFT = 43 +API_RTC = 44 +API_ROCTX = 45 + +HIP_UNSUPPORTED = 46 +API_PYTORCH = 1337 +API_CAFFE2 = 1338 +API_C10 = 1339 +API_ROCMSMI = 1340 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf93cf5e6d61122eabd9dc7a4884fcab9c4dad6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py @@ -0,0 +1,9492 @@ +import collections +import os + +from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, + API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, + API_SPECIAL, API_ROCMSMI, CONV_CACHE, CONV_CONTEXT, CONV_D3D9, + CONV_D3D10, CONV_D3D11, CONV_DEF, CONV_DEVICE, + CONV_DEVICE_FUNC, CONV_EGL, CONV_ERROR, CONV_EVENT, + CONV_EXEC, CONV_GL, CONV_GRAPHICS, CONV_INCLUDE, + CONV_INCLUDE_CUDA_MAIN_H, CONV_INIT, CONV_JIT, + CONV_MATH_FUNC, CONV_MEM, CONV_MODULE, + CONV_NUMERIC_LITERAL, CONV_OCCUPANCY, CONV_OTHER, + CONV_PEER, CONV_SPECIAL_FUNC, CONV_STREAM, + CONV_SURFACE, CONV_TEX, CONV_THREAD, CONV_TYPE, + CONV_VDPAU, CONV_VERSION, HIP_UNSUPPORTED) + +""" Mapping of CUDA functions, include files, constants, and types to ROCm/HIP equivalents +This closely follows the implementation in hipify-clang +https://github.com/ROCm/hip/blob/59071b895ed1c86d9698b4c859cefcdd5acda06f/hipify-clang/src/CUDA2HipMap.cpp +and its structure. +There are different maps for fundamental names, include files, identifies, sparse, and +PyTorch specific translations. +Each of the entries in these maps translates a CUDA string to a tuple containing the +ROCm/HIP string, a type and API annotation and - optionally - an annotation if it is not +supported in ROCm/HIP yet. +""" + +_IS_FBCODE = os.environ.get("IS_FBCODE", "0") == "1" + +# FBCODE compiles against rccl sources instead of an installed rccl package. +# The header location is src/rccl.h versus rccl/rccl.h, respectively. +_RCCL_HEADER = "" if _IS_FBCODE else "" + +# List of math functions that should be replaced inside device code only. +MATH_TRANSPILATIONS = collections.OrderedDict( + [ + ("std::max", ("::max")), + ("std::min", ("::min")), + ("std::ceil", ("::ceil")), + ("std::floor", ("::floor")), + ("std::exp", ("::exp")), + ("std::log", ("::log")), + ("std::pow", ("::pow")), + ("std::fabs", ("::fabs")), + ("std::fmod", ("::fmod")), + ("std::remainder", ("::remainder")), + ("std::frexp", ("::frexp")), + ] +) + +# pyrefly: ignore [no-matching-overload] +CUDA_TYPE_NAME_MAP = collections.OrderedDict( + [ + ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), + ("cudaError_t", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ("cudaError", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ARRAY3D_DESCRIPTOR", + ("HIP_ARRAY3D_DESCRIPTOR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_ARRAY_DESCRIPTOR", ("HIP_ARRAY_DESCRIPTOR", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY2D", ("hip_Memcpy2D", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY3D", ("HIP_MEMCPY3D", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_MEMCPY3D_PEER", + ("HIP_MEMCPY3D_PEER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_POINTER_ATTRIBUTE_P2P_TOKENS", + ( + "HIP_POINTER_ATTRIBUTE_P2P_TOKENS", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_RESOURCE_DESC", + ("HIP_RESOURCE_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_RESOURCE_VIEW_DESC", + ("HIP_RESOURCE_VIEW_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUipcEventHandle", + ("hipIpcEventHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUipcMemHandle", ("hipIpcMemHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUaddress_mode", ("hipAddress_mode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUarray_cubemap_face", + ("hipArray_cubemap_face", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUarray_format", ("hipArray_format", CONV_TYPE, API_DRIVER)), + ("CUcomputemode", ("hipComputemode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmem_advise", ("hipMemAdvise", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUmem_range_attribute", + ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUctx_flags", ("hipCctx_flags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUdevice", ("hipDevice_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute_enum", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUpointer_attribute", ("hipPointer_attribute", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", ("HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_BUFFER_ID", ("HIP_POINTER_ATTRIBUTE_BUFFER_ID", CONV_TYPE, API_DRIVER)), + ("CUdeviceptr", ("hipDeviceptr_t", CONV_TYPE, API_DRIVER)), + ("CUarray_st", ("hipArray", CONV_TYPE, API_DRIVER)), + ("CUarray", ("hipArray *", CONV_TYPE, API_DRIVER)), + ("CUdevprop_st", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUdevprop", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUfunction", ("hipFunction_t", CONV_TYPE, API_DRIVER)), + ( + "CUgraphicsResource", + ("hipGraphicsResource_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmipmappedArray", + ("hipMipmappedArray_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute_enum", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags_enum", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags_enum", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags_enum", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUfunc_cache_enum", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUfunc_cache", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUipcMem_flags", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUipcMem_flags_enum", + ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_cacheMode", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_cacheMode_enum", + ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_fallback", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_fallback_enum", + ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_option", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_option_enum", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_target", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjit_target_enum", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjitInputType", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjitInputType_enum", + ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc_st", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc_v1", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), + ( + "CUmemAttach_flags", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmemAttach_flags_enum", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUmemAllocationGranularity_flags", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationGranularity_flags_enum", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationHandleType", ("hipMemAllocationHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationHandleType_enum", ("hipMemAllocationHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp_st", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp_v1", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationType", ("hipMemAllocationType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationType_enum", ("hipMemAllocationType", CONV_TYPE, API_DRIVER)), + ("CUmemGenericAllocationHandle", ("hipMemGenericAllocationHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemGenericAllocationHandle_v1", ("hipMemGenericAllocationHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemHandleType", ("hipMemHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemHandleType_enum", ("hipMemHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemLocation", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemLocationType", ("hipMemLocationType", CONV_TYPE, API_DRIVER)), + ("CUmemLocationType_enum", ("hipMemLocationType", CONV_TYPE, API_DRIVER)), + ("CUmemLocation_st", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemLocation_v1", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemOperationType", ("hipMemOperationType", CONV_TYPE, API_DRIVER)), + ("CUmemOperationType_enum", ("hipMemOperationType", CONV_TYPE, API_DRIVER)), + ("CUmemPoolHandle_st", ("ihipMemPoolHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps_st", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps_v1", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData_st", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData_v1", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPool_attribute", ("hipMemPoolAttr", CONV_TYPE, API_DRIVER)), + ("CUmemPool_attribute_enum", ("hipMemPoolAttr", CONV_TYPE, API_DRIVER)), + ("CUmem_advise_enum", ("hipMemoryAdvise", CONV_TYPE, API_DRIVER)), + ("CUmem_range_attribute_enum", ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER)), + ("CUmemoryPool", ("hipMemPool_t", CONV_TYPE, API_DRIVER)), + ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUresourcetype_enum", + ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUresourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUresourceViewFormat_enum", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUsharedconfig", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUsharedconfig_enum", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUcontext", ("hipCtx_t", CONV_TYPE, API_DRIVER)), + ("CUmodule", ("hipModule_t", CONV_TYPE, API_DRIVER)), + ("CUstream", ("hipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstream_st", ("ihipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstreamCallback", ("hipStreamCallback_t", CONV_TYPE, API_DRIVER)), + ("CUsurfObject", ("hipSurfaceObject", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUsurfref", + ("hipSurfaceReference_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUtexObject", ("hipTextureObject_t", CONV_TYPE, API_DRIVER)), + ("CUtexref", ("textureReference", CONV_TYPE, API_DRIVER)), + ("CUstream_flags", ("hipStreamFlags", CONV_TYPE, API_DRIVER)), + ( + "CUstreamWaitValue_flags", + ("hipStreamWaitValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamWriteValue_flags", + ("hipStreamWriteValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamBatchMemOpType", + ("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUdevice_P2PAttribute", + ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), + ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("CUGLmap_flags", ("hipGLMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUd3d9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9map_flags", + ("hipD3D9MapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9register_flags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10map_flags", + ("hipD3D10MapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10register_flags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection_st", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType_t", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamCallback_t", ("hipStreamCallback_t", CONV_TYPE, API_RUNTIME)), + ("cudaArray", ("hipArray", CONV_MEM, API_RUNTIME)), + ("cudaArray_t", ("hipArray_t", CONV_MEM, API_RUNTIME)), + ("cudaArray_const_t", ("hipArray_const_t", CONV_MEM, API_RUNTIME)), + ("cudaMipmappedArray_t", ("hipMipmappedArray_t", CONV_MEM, API_RUNTIME)), + ( + "cudaMipmappedArray_const_t", + ("hipMipmappedArray_const_t", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayDefault", ("hipArrayDefault", CONV_MEM, API_RUNTIME)), + ("cudaArrayLayered", ("hipArrayLayered", CONV_MEM, API_RUNTIME)), + ( + "cudaArraySurfaceLoadStore", + ("hipArraySurfaceLoadStore", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayCubemap", ("hipArrayCubemap", CONV_MEM, API_RUNTIME)), + ("cudaArrayTextureGather", ("hipArrayTextureGather", CONV_MEM, API_RUNTIME)), + ("cudaMemoryAdvise", ("hipMemoryAdvise", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeAttribute", + ("hipMemRangeAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyKind", ("hipMemcpyKind", CONV_MEM, API_RUNTIME)), + ("cudaMemoryType", ("hipMemoryType", CONV_MEM, API_RUNTIME)), + ("cudaExtent", ("hipExtent", CONV_MEM, API_RUNTIME)), + ("cudaPitchedPtr", ("hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("cudaPos", ("hipPos", CONV_MEM, API_RUNTIME)), + ("cudaEvent_t", ("hipEvent_t", CONV_TYPE, API_RUNTIME)), + ("cudaStream_t", ("hipStream_t", CONV_TYPE, API_RUNTIME)), + ("cudaHostFn_t", ("hipHostFn_t", CONV_TYPE, API_RUNTIME)), + ("cudaPointerAttributes", ("hipPointerAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceAttr", ("hipDeviceAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceProp", ("hipDeviceProp_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceP2PAttr", + ("hipDeviceP2PAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeMode", + ("hipComputeMode", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaFuncCache", ("hipFuncCache_t", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSharedMemConfig", ("hipSharedMemConfig", CONV_TYPE, API_RUNTIME)), + ("cudaLimit", ("hipLimit_t", CONV_TYPE, API_RUNTIME)), + ("cudaOutputMode", ("hipOutputMode", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaTextureReadMode", ("hipTextureReadMode", CONV_TEX, API_RUNTIME)), + ("cudaTextureFilterMode", ("hipTextureFilterMode", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatKind", ("hipChannelFormatKind", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatDesc", ("hipChannelFormatDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceDesc", ("hipResourceDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewDesc", ("hipResourceViewDesc", CONV_TEX, API_RUNTIME)), + ("cudaTextureDesc", ("hipTextureDesc", CONV_TEX, API_RUNTIME)), + ( + "surfaceReference", + ("hipSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureObject_t", ("hipTextureObject_t", CONV_TEX, API_RUNTIME)), + ("cudaResourceType", ("hipResourceType", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_RUNTIME)), + ("cudaTextureAddressMode", ("hipTextureAddressMode", CONV_TEX, API_RUNTIME)), + ( + "cudaSurfaceBoundaryMode", + ("hipSurfaceBoundaryMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSurfaceFormatMode", + ("hipSurfaceFormatMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureType1D", ("hipTextureType1D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType2D", ("hipTextureType2D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType3D", ("hipTextureType3D", CONV_TEX, API_RUNTIME)), + ("cudaTextureTypeCubemap", ("hipTextureTypeCubemap", CONV_TEX, API_RUNTIME)), + ( + "cudaTextureType1DLayered", + ("hipTextureType1DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureType2DLayered", + ("hipTextureType2DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureTypeCubemapLayered", + ("hipTextureTypeCubemapLayered", CONV_TEX, API_RUNTIME), + ), + ("cudaIpcEventHandle_t", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcEventHandle_st", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_t", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_st", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphicsCubeFace", + ("hipGraphicsCubeFace", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlags", + ("hipGraphicsMapFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceList", + ("hipGLDeviceList", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaGLMapFlags", ("hipGLMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaD3D9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapFlags", + ("hipD3D10MapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cublasHandle_t", ("hipblasHandle_t", CONV_TYPE, API_BLAS)), + ("cublasOperation_t", ("hipblasOperation_t", CONV_TYPE, API_BLAS)), + ("cublasStatus_t", ("hipblasStatus_t", CONV_TYPE, API_BLAS)), + ("cublasFillMode_t", ("hipblasFillMode_t", CONV_TYPE, API_BLAS)), + ("cublasDiagType_t", ("hipblasDiagType_t", CONV_TYPE, API_BLAS)), + ("cublasSideMode_t", ("hipblasSideMode_t", CONV_TYPE, API_BLAS)), + ("cublasPointerMode_t", ("hipblasPointerMode_t", CONV_TYPE, API_BLAS)), + ("cublasGemmAlgo_t", ("hipblasGemmAlgo_t", CONV_TYPE, API_BLAS)), + ( + "cublasAtomicsMode_t", + ("hipblasAtomicsMode_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDataType_t", + ("hipblasDatatype_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ("curandStatus", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandStatus_t", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandRngType", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandRngType_t", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandGenerator_st", ("hiprandGenerator_st", CONV_TYPE, API_RAND)), + ("curandGenerator_t", ("hiprandGenerator_t", CONV_TYPE, API_RAND)), + ( + "curandDirectionVectorSet", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDirectionVectorSet_t", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandOrdering", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandOrdering_t", + ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_st", + ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_t", + ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_st", + ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_t", + ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_st", + ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_t", + ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_st", + ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_t", + ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDiscreteDistribution_st", + ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND), + ), + ( + "curandDiscreteDistribution_t", + ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND), + ), + ("curandMethod", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ("curandMethod_t", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandDirectionVectors32_t", + ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND), + ), + ( + "curandDirectionVectors64_t", + ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateMtgp32_t", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ("curandStateMtgp32", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ( + "curandStateScrambledSobol64_t", + ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateSobol64_t", + ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateScrambledSobol32_t", + ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateSobol32_t", ("hiprandStateSobol32_t", CONV_TYPE, API_RAND)), + ("curandStateMRG32k3a_t", ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND)), + ( + "curandStatePhilox4_32_10_t", + ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND), + ), + ("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)), + ("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), + ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), + ("cudaGraphNode_t", ("hipGraphNode_t", CONV_TYPE, API_RAND)), + ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), + ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), + ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), + ] +) + +# pyrefly: ignore [no-matching-overload] +CUDA_INCLUDE_MAP = collections.OrderedDict( + [ + # since pytorch uses "\b{pattern}\b" as the actual re pattern, + # patterns listed here have to begin and end with alnum chars + ( + "include " to differentiate + ("", (_RCCL_HEADER, CONV_INCLUDE, API_RUNTIME)), + ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), + ("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)), + ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_raking_layout.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/cub.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/config.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_ptx.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_type.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_run_length_encode.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_load.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_store.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_select.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("nvtx3/nvtx3.hpp", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), + ("nvToolsExt.h", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), + ("nvml.h", ("rocm_smi/rocm_smi.h", CONV_INCLUDE, API_ROCMSMI)), + ] +) + +# pyrefly: ignore [no-matching-overload] +CUDA_IDENTIFIER_MAP = collections.OrderedDict( + [ + ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), + ( + "CUDA_ERROR_INVALID_CONTEXT", + ("hipErrorInvalidContext", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", + ("hipErrorContextAlreadyCurrent", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_ARRAY_IS_MAPPED", + ("hipErrorArrayIsMapped", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_ALREADY_MAPPED", ("hipErrorAlreadyMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_ALREADY_ACQUIRED", + ("hipErrorAlreadyAcquired", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_NOT_MAPPED", ("hipErrorNotMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_NOT_MAPPED_AS_ARRAY", + ("hipErrorNotMappedAsArray", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_NOT_MAPPED_AS_POINTER", + ("hipErrorNotMappedAsPointer", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_IN_USE", + ("hipErrorContextAlreadyInUse", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_INVALID_SOURCE", ("hipErrorInvalidSource", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_FILE_NOT_FOUND", ("hipErrorFileNotFound", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_NOT_FOUND", ("hipErrorNotFound", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING", + ( + "hipErrorLaunchIncompatibleTexturing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE", + ("hipErrorPrimaryContextActive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_CONTEXT_IS_DESTROYED", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_PERMITTED", + ("hipErrorNotPermitted", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_SUPPORTED", + ("hipErrorNotSupported", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMissingConfiguration", + ("hipErrorMissingConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorPriorLaunchFailure", + ("hipErrorPriorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDeviceFunction", + ("hipErrorInvalidDeviceFunction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidConfiguration", + ("hipErrorInvalidConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPitchValue", + ("hipErrorInvalidPitchValue", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidSymbol", + ("hipErrorInvalidSymbol", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidHostPointer", + ("hipErrorInvalidHostPointer", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDevicePointer", + ("hipErrorInvalidDevicePointer", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidTexture", + ("hipErrorInvalidTexture", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidTextureBinding", + ("hipErrorInvalidTextureBinding", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidChannelDescriptor", + ( + "hipErrorInvalidChannelDescriptor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorInvalidMemcpyDirection", + ("hipErrorInvalidMemcpyDirection", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAddressOfConstant", + ("hipErrorAddressOfConstant", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureFetchFailed", + ("hipErrorTextureFetchFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureNotBound", + ("hipErrorTextureNotBound", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSynchronizationError", + ("hipErrorSynchronizationError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidFilterSetting", + ("hipErrorInvalidFilterSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidNormSetting", + ("hipErrorInvalidNormSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMixedDeviceExecution", + ("hipErrorMixedDeviceExecution", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotYetImplemented", + ("hipErrorNotYetImplemented", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMemoryValueTooLarge", + ("hipErrorMemoryValueTooLarge", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInsufficientDriver", + ("hipErrorInsufficientDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSetOnActiveProcess", + ("hipErrorSetOnActiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorContextIsDestroyed", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidSurface", + ("hipErrorInvalidSurface", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateVariableName", + ("hipErrorDuplicateVariableName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateTextureName", + ("hipErrorDuplicateTextureName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateSurfaceName", + ("hipErrorDuplicateSurfaceName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDevicesUnavailable", + ("hipErrorDevicesUnavailable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIncompatibleDriverContext", + ( + "hipErrorIncompatibleDriverContext", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorDeviceAlreadyInUse", + ("hipErrorDeviceAlreadyInUse", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchMaxDepthExceeded", + ("hipErrorLaunchMaxDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedTex", + ("hipErrorLaunchFileScopedTex", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedSurf", + ("hipErrorLaunchFileScopedSurf", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSyncDepthExceeded", + ("hipErrorSyncDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchPendingCountExceeded", + ( + "hipErrorLaunchPendingCountExceeded", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorNotPermitted", + ("hipErrorNotPermitted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotSupported", + ("hipErrorNotSupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorStartupFailure", + ("hipErrorStartupFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorApiFailureBase", + ("hipErrorApiFailureBase", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_SUCCESS", ("hipSuccess", CONV_TYPE, API_DRIVER)), + ("cudaSuccess", ("hipSuccess", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_VALUE", ("hipErrorInvalidValue", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidValue", ("hipErrorInvalidValue", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_OUT_OF_MEMORY", + ("hipErrorMemoryAllocation", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorMemoryAllocation", + ("hipErrorMemoryAllocation", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_NOT_INITIALIZED", + ("hipErrorNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInitializationError", + ("hipErrorInitializationError", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_DEINITIALIZED", ("hipErrorDeinitialized", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorCudartUnloading", + ("hipErrorDeinitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_DISABLED", + ("hipErrorProfilerDisabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerDisabled", + ("hipErrorProfilerDisabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_NOT_INITIALIZED", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerNotInitialized", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STARTED", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStarted", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STOPPED", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStopped", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_NO_DEVICE", ("hipErrorNoDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorNoDevice", ("hipErrorNoDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_DEVICE", ("hipErrorInvalidDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidDevice", ("hipErrorInvalidDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_IMAGE", ("hipErrorInvalidImage", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorInvalidKernelImage", + ("hipErrorInvalidImage", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_MAP_FAILED", ("hipErrorMapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorMapBufferObjectFailed", + ("hipErrorMapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_UNMAP_FAILED", ("hipErrorUnmapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorUnmapBufferObjectFailed", + ("hipErrorUnmapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NO_BINARY_FOR_GPU", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorNoKernelImageForDevice", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ECC_UNCORRECTABLE", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorECCUncorrectable", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNSUPPORTED_LIMIT", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorUnsupportedLimit", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_UNSUPPORTED", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessUnsupported", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PTX", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidPtx", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_GRAPHICS_CONTEXT", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidGraphicsContext", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NVLINK_UNCORRECTABLE", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNvlinkUncorrectable", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND", + ("hipErrorSharedObjectSymbolNotFound", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectSymbolNotFound", + ( + "hipErrorSharedObjectSymbolNotFound", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_INIT_FAILED", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectInitFailed", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_OPERATING_SYSTEM", + ("hipErrorOperatingSystem", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorOperatingSystem", + ("hipErrorOperatingSystem", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_HANDLE", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidResourceHandle", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_NOT_READY", ("hipErrorNotReady", CONV_TYPE, API_DRIVER)), + ("cudaErrorNotReady", ("hipErrorNotReady", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_ILLEGAL_ADDRESS", + ("hipErrorIllegalAddress", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorIllegalAddress", + ("hipErrorIllegalAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorLaunchOutOfResources", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_LAUNCH_TIMEOUT", ("hipErrorLaunchTimeOut", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorLaunchTimeout", + ("hipErrorLaunchTimeOut", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessAlreadyEnabled", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessNotEnabled", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_ASSERT", + ("hipErrorAssert", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAssert", + ("hipErrorAssert", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_TOO_MANY_PEERS", + ("hipErrorTooManyPeers", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTooManyPeers", + ("hipErrorTooManyPeers", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryAlreadyRegistered", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryNotRegistered", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HARDWARE_STACK_ERROR", + ("hipErrorHardwareStackError", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorHardwareStackError", + ("hipErrorHardwareStackError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ILLEGAL_INSTRUCTION", + ("hipErrorIllegalInstruction", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIllegalInstruction", + ("hipErrorIllegalInstruction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_MISALIGNED_ADDRESS", + ("hipErrorMisalignedAddress", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMisalignedAddress", + ("hipErrorMisalignedAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_ADDRESS_SPACE", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidAddressSpace", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PC", + ("hipErrorInvalidPc", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPc", + ("hipErrorInvalidPc", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_FAILED", + ("hipErrorLaunchFailure", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFailure", + ("hipErrorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNKNOWN", + ("hipErrorUnknown", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaErrorUnknown", ("hipErrorUnknown", CONV_TYPE, API_RUNTIME)), + ( + "CU_TR_ADDRESS_MODE_WRAP", + ("HIP_TR_ADDRESS_MODE_WRAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_CLAMP", + ("HIP_TR_ADDRESS_MODE_CLAMP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_MIRROR", + ("HIP_TR_ADDRESS_MODE_MIRROR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_BORDER", + ("HIP_TR_ADDRESS_MODE_BORDER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_X", + ("HIP_CUBEMAP_FACE_POSITIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_X", + ("HIP_CUBEMAP_FACE_NEGATIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Y", + ("HIP_CUBEMAP_FACE_POSITIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Y", + ("HIP_CUBEMAP_FACE_NEGATIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Z", + ("HIP_CUBEMAP_FACE_POSITIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Z", + ("HIP_CUBEMAP_FACE_NEGATIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT8", + ("HIP_AD_FORMAT_UNSIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT16", + ("HIP_AD_FORMAT_UNSIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT32", + ("HIP_AD_FORMAT_UNSIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT8", + ("HIP_AD_FORMAT_SIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT16", + ("HIP_AD_FORMAT_SIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT32", + ("HIP_AD_FORMAT_SIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ("CU_AD_FORMAT_HALF", ("HIP_AD_FORMAT_HALF", CONV_TYPE, API_DRIVER)), + ("CU_AD_FORMAT_FLOAT", ("HIP_AD_FORMAT_FLOAT", CONV_TYPE, API_DRIVER)), + ( + "CU_COMPUTEMODE_DEFAULT", + ("hipComputeModeDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE", + ("hipComputeModeExclusive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_PROHIBITED", + ("hipComputeModeProhibited", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE_PROCESS", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_READ_MOSTLY", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_READ_MOSTLY", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_PREFERRED_LOCATION", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_SET_ACCESSED_BY", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_ACCESSED_BY", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_CTX_SCHED_AUTO", + ("HIP_CTX_SCHED_AUTO", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_SPIN", + ("HIP_CTX_SCHED_SPIN", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_YIELD", + ("HIP_CTX_SCHED_YIELD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_BLOCKING_SYNC", + ("HIP_CTX_SCHED_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_BLOCKING_SYNC", + ("HIP_CTX_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_MASK", + ("HIP_CTX_SCHED_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_MAP_HOST", + ("HIP_CTX_MAP_HOST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_LMEM_RESIZE_TO_MAX", + ("HIP_CTX_LMEM_RESIZE_TO_MAX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_FLAGS_MASK", + ("HIP_CTX_FLAGS_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_POINTER", + ("HIP_LAUNCH_PARAM_BUFFER_POINTER", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_SIZE", + ("HIP_LAUNCH_PARAM_BUFFER_SIZE", CONV_TYPE, API_DRIVER), + ), + ("CU_LAUNCH_PARAM_END", ("HIP_LAUNCH_PARAM_END", CONV_TYPE, API_DRIVER)), + ( + "CU_IPC_HANDLE_SIZE", + ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_DEVICEMAP", + ("HIP_MEMHOSTALLOC_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_PORTABLE", + ("HIP_MEMHOSTALLOC_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_WRITECOMBINED", + ("HIP_MEMHOSTALLOC_WRITECOMBINED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_DEVICEMAP", + ("HIP_MEMHOSTREGISTER_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_IOMEMORY", + ("HIP_MEMHOSTREGISTER_IOMEMORY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_PORTABLE", + ("HIP_MEMHOSTREGISTER_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PARAM_TR_DEFAULT", + ("HIP_PARAM_TR_DEFAULT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_LEGACY", + ("HIP_STREAM_LEGACY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_PER_THREAD", + ("HIP_STREAM_PER_THREAD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSA_OVERRIDE_FORMAT", + ("HIP_TRSA_OVERRIDE_FORMAT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_NORMALIZED_COORDINATES", + ("HIP_TRSF_NORMALIZED_COORDINATES", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_READ_AS_INTEGER", + ("HIP_TRSF_READ_AS_INTEGER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TRSF_SRGB", ("HIP_TRSF_SRGB", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_ARRAY3D_2DARRAY", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_CUBEMAP", + ("HIP_ARRAY3D_CUBEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_DEPTH_TEXTURE", + ("HIP_ARRAY3D_DEPTH_TEXTURE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_LAYERED", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_SURFACE_LDST", + ("HIP_ARRAY3D_SURFACE_LDST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_TEXTURE_GATHER", + ("HIP_ARRAY3D_TEXTURE_GATHER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipDeviceAttributeMaxThreadsPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY", + ( + "hipDeviceAttributeTotalConstantMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + ("hipDeviceAttributeWarpSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_PITCH", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CLOCK_RATE", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GPU_OVERLAP", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT", + ( + "hipDeviceAttributeMultiprocessorCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_INTEGRATED", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_MODE", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ECC_ENABLED", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", + ("hipDeviceAttributePciBusId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_TCC_DRIVER", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE", + ( + "hipDeviceAttributeMemoryClockRate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER", + ( + "hipDeviceAttributeCanTex2DGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY", + ("hipDeviceAttributeManagedMemory", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID", + ( + "hipDeviceAttributeMultiGpuBoardGroupId", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX", + ("hipDeviceAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_CONTEXT", + ("hipPointerAttributeContext", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_MEMORY_TYPE", + ("hipPointerAttributeMemoryType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_DEVICE_POINTER", + ( + "hipPointerAttributeDevicePointer", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_POINTER_ATTRIBUTE_HOST_POINTER", + ("hipPointerAttributeHostPointer", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_P2P_TOKENS", + ("hipPointerAttributeP2pTokens", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_SYNC_MEMOPS", + ("hipPointerAttributeSyncMemops", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_BUFFER_ID", + ("hipPointerAttributeBufferId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_IS_MANAGED", + ("hipPointerAttributeIsManaged", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipFuncAttributeMaxThreadsPerBlocks", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + ("hipFuncAttributeSharedSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + ("hipFuncAttributeConstSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + ("hipFuncAttributeLocalSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_NUM_REGS", + ("hipFuncAttributeNumRegs", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_PTX_VERSION", + ("hipFuncAttributePtxVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_BINARY_VERSION", + ("hipFuncAttributeBinaryVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_CACHE_MODE_CA", + ("hipFuncAttributeCacheModeCA", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX", + ("hipFuncAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE", + ("hipGraphicsMapFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY", + ("hipGraphicsMapFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ("hipGraphicsMapFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_NONE", + ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_OCCUPANCY_DEFAULT", + ("hipOccupancyDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_CACHE_PREFER_NONE", + ("hipFuncCachePreferNone", CONV_CACHE, API_DRIVER), + ), + ( + "CU_FUNC_CACHE_PREFER_SHARED", + ("hipFuncCachePreferShared", CONV_CACHE, API_DRIVER), + ), + ("CU_FUNC_CACHE_PREFER_L1", ("hipFuncCachePreferL1", CONV_CACHE, API_DRIVER)), + ( + "CU_FUNC_CACHE_PREFER_EQUAL", + ("hipFuncCachePreferEqual", CONV_CACHE, API_DRIVER), + ), + ( + "CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_IPC_HANDLE_SIZE", ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER)), + ( + "CU_JIT_CACHE_OPTION_NONE", + ("hipJitCacheModeOptionNone", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CG", + ("hipJitCacheModeOptionCG", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CA", + ("hipJitCacheModeOptionCA", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_PTX", + ("hipJitFallbackPreferPtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_BINARY", + ("hipJitFallbackPreferBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_JIT_MAX_REGISTERS", ("hipJitOptionMaxRegisters", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_THREADS_PER_BLOCK", + ("hipJitOptionThreadsPerBlock", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_WALL_TIME", ("hipJitOptionWallTime", CONV_JIT, API_DRIVER)), + ("CU_JIT_INFO_LOG_BUFFER", ("hipJitOptionInfoLogBuffer", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionInfoLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER", + ("hipJitOptionErrorLogBuffer", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionErrorLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_OPTIMIZATION_LEVEL", + ("hipJitOptionOptimizationLevel", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_TARGET_FROM_CUCONTEXT", + ("hipJitOptionTargetFromContext", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_TARGET", ("hipJitOptionTarget", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_FALLBACK_STRATEGY", + ("hipJitOptionFallbackStrategy", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_GENERATE_DEBUG_INFO", + ("hipJitOptionGenerateDebugInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_LOG_VERBOSE", ("hipJitOptionLogVerbose", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_GENERATE_LINE_INFO", + ("hipJitOptionGenerateLineInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_CACHE_MODE", ("hipJitOptionCacheMode", CONV_JIT, API_DRIVER)), + ("CU_JIT_NEW_SM3X_OPT", ("hipJitOptionSm3xOpt", CONV_JIT, API_DRIVER)), + ("CU_JIT_FAST_COMPILE", ("hipJitOptionFastCompile", CONV_JIT, API_DRIVER)), + ("CU_JIT_NUM_OPTIONS", ("hipJitOptionNumOptions", CONV_JIT, API_DRIVER)), + ( + "CU_TARGET_COMPUTE_10", + ("hipJitTargetCompute10", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_11", + ("hipJitTargetCompute11", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_12", + ("hipJitTargetCompute12", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_13", + ("hipJitTargetCompute13", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_20", + ("hipJitTargetCompute20", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_21", + ("hipJitTargetCompute21", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_30", + ("hipJitTargetCompute30", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_32", + ("hipJitTargetCompute32", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_35", + ("hipJitTargetCompute35", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_37", + ("hipJitTargetCompute37", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_50", + ("hipJitTargetCompute50", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_52", + ("hipJitTargetCompute52", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_53", + ("hipJitTargetCompute53", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_60", + ("hipJitTargetCompute60", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_61", + ("hipJitTargetCompute61", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_62", + ("hipJitTargetCompute62", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_CUBIN", + ("hipJitInputTypeBin", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_PTX", + ("hipJitInputTypePtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_FATBINARY", + ("hipJitInputTypeFatBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_OBJECT", + ("hipJitInputTypeObject", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_LIBRARY", + ("hipJitInputTypeLibrary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_NUM_INPUT_TYPES", + ("hipJitInputTypeNumInputTypes", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_PRINTF_FIFO_SIZE", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_MALLOC_HEAP_SIZE", + ("hipLimitMallocHeapSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_GLOBAL", + ("hipMemAttachGlobal", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_HOST", + ("hipMemAttachHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_SINGLE", + ("hipMemAttachSingle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_HOST", + ("hipMemTypeHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_DEVICE", + ("hipMemTypeDevice", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_ARRAY", + ("hipMemTypeArray", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_UNIFIED", + ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_MEMHOSTREGISTER_READ_ONLY", ("hipHostRegisterReadOnly", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RELEASE_THRESHOLD", ("hipMemPoolAttrReleaseThreshold", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT", ("hipMemPoolAttrReservedMemCurrent", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH", ("hipMemPoolAttrReservedMemHigh", CONV_TYPE, API_DRIVER)), + ( + "CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES", + ("hipMemPoolReuseAllowInternalDependencies", CONV_TYPE, API_DRIVER) + ), + ("CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC", ("hipMemPoolReuseAllowOpportunistic", CONV_TYPE, API_DRIVER)), + ( + "CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES", + ("hipMemPoolReuseFollowEventDependencies", CONV_TYPE, API_DRIVER) + ), + ("CU_MEMPOOL_ATTR_USED_MEM_CURRENT", ("hipMemPoolAttrUsedMemCurrent", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_USED_MEM_HIGH", ("hipMemPoolAttrUsedMemHigh", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_NONE", ("hipMemAccessFlagsProtNone", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_READ", ("hipMemAccessFlagsProtRead", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_READWRITE", ("hipMemAccessFlagsProtReadWrite", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_INVALID", ("hipMemAllocationTypeInvalid", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_MAX", ("hipMemAllocationTypeMax", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_PINNED", ("hipMemAllocationTypePinned", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOC_GRANULARITY_MINIMUM", ("hipMemAllocationGranularityMinimum", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOC_GRANULARITY_RECOMMENDED", ("hipMemAllocationGranularityRecommended", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_GENERIC", ("hipMemHandleTypeGeneric", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_NONE", ("hipMemHandleTypeNone", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR", ("hipMemHandleTypePosixFileDescriptor", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_WIN32", ("hipMemHandleTypeWin32", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_WIN32_KMT", ("hipMemHandleTypeWin32Kmt", CONV_TYPE, API_DRIVER)), + ("CU_MEM_LOCATION_TYPE_DEVICE", ("hipMemLocationTypeDevice", CONV_TYPE, API_DRIVER)), + ("CU_MEM_LOCATION_TYPE_INVALID", ("hipMemLocationTypeInvalid", CONV_TYPE, API_DRIVER)), + ("CU_MEM_OPERATION_TYPE_MAP", ("hipMemOperationTypeMap", CONV_TYPE, API_DRIVER)), + ("CU_MEM_OPERATION_TYPE_UNMAP", ("hipMemOperationTypeUnmap", CONV_TYPE, API_DRIVER)), + ( + "CU_RESOURCE_TYPE_ARRAY", + ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_MIPMAPPED_ARRAY", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_LINEAR", + ("hipResourceTypeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_PITCH2D", + ("hipResourceTypePitch2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_RES_VIEW_FORMAT_NONE", ("hipResViewFormatNone", CONV_TEX, API_DRIVER)), + ( + "CU_RES_VIEW_FORMAT_UINT_1X8", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X8", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X8", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X8", + ("hipResViewFormatSignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X8", + ("hipResViewFormatSignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X8", + ("hipResViewFormatSignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X16", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X16", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X16", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X16", + ("hipResViewFormatSignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X16", + ("hipResViewFormatSignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X16", + ("hipResViewFormatSignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X32", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X32", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X32", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X32", + ("hipResViewFormatSignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X32", + ("hipResViewFormatSignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X32", + ("hipResViewFormatSignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X16", + ("hipResViewFormatHalf1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X16", + ("hipResViewFormatHalf2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X16", + ("hipResViewFormatHalf4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X32", + ("hipResViewFormatFloat1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X32", + ("hipResViewFormatFloat2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X32", + ("hipResViewFormatFloat4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_DRIVER), + ), + ("CU_STREAM_DEFAULT", ("hipStreamDefault", CONV_TYPE, API_DRIVER)), + ("CU_STREAM_NON_BLOCKING", ("hipStreamNonBlocking", CONV_TYPE, API_DRIVER)), + ( + "CU_STREAM_WAIT_VALUE_GEQ", + ("hipStreamWaitValueGeq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_EQ", + ("hipStreamWaitValueEq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_AND", + ("hipStreamWaitValueAnd", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_FLUSH", + ("hipStreamWaitValueFlush", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_DEFAULT", + ("hipStreamWriteValueDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER", + ( + "hipStreamWriteValueNoMemoryBarrier", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_STREAM_MEM_OP_WAIT_VALUE_32", + ("hipStreamBatchMemOpWaitValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_WRITE_VALUE_32", + ("hipStreamBatchMemOpWriteValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES", + ( + "hipStreamBatchMemOpFlushRemoteWrites", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGetErrorName", + ("hipGetErrorName", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGetErrorString", + ("hipDrvGetErrorString", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuInit", ("hipInit", CONV_INIT, API_DRIVER)), + ("cuDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_DRIVER)), + ("cuCtxCreate", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxCreate_v2", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy_v2", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetApiVersion", ("hipCtxGetApiVersion", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCacheConfig", ("hipCtxGetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCurrent", ("hipCtxGetCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetDevice", ("hipCtxGetDevice", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetFlags", ("hipCtxGetFlags", CONV_CONTEXT, API_DRIVER)), + ("cuDeviceGetUuid", ("hipDeviceGetUuid", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxGetLimit", + ("hipCtxGetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxGetSharedMemConfig", + ("hipCtxGetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuCtxGetStreamPriorityRange", + ("hipCtxGetStreamPriorityRange", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuCtxPopCurrent_v2", ("hipCtxPopCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxPushCurrent_v2", ("hipCtxPushCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCacheConfig", ("hipCtxSetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCurrent", ("hipCtxSetCurrent", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxSetLimit", + ("hipCtxSetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxSetSharedMemConfig", + ("hipCtxSetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ("cuCtxSynchronize", ("hipCtxSynchronize", CONV_CONTEXT, API_DRIVER)), + ("cuCtxAttach", ("hipCtxAttach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxDetach", ("hipCtxDetach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxEnablePeerAccess", ("hipCtxEnablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuCtxDisablePeerAccess", ("hipCtxDisablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_DRIVER)), + ( + "cuDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_PEER, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuDevicePrimaryCtxGetState", + ("hipDevicePrimaryCtxGetState", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRelease", + ("hipDevicePrimaryCtxRelease", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxReset", + ("hipDevicePrimaryCtxReset", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRetain", + ("hipDevicePrimaryCtxRetain", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxSetFlags", + ("hipDevicePrimaryCtxSetFlags", CONV_CONTEXT, API_DRIVER), + ), + ("cuDeviceGet", ("hipDeviceGet", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetName", ("hipDeviceGetName", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetCount", ("hipGetDeviceCount", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetByPCIBusId", ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceTotalMem_v2", ("hipDeviceTotalMem", CONV_DEVICE, API_DRIVER)), + ( + "cuDeviceComputeCapability", + ("hipDeviceComputeCapability", CONV_DEVICE, API_DRIVER), + ), + ("cuDeviceGetProperties", ("hipGetDeviceProperties", CONV_DEVICE, API_DRIVER)), + ("cuLinkAddData", ("hipLinkAddData", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkAddFile", ("hipLinkAddFile", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLinkComplete", + ("hipLinkComplete", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLinkCreate", ("hipLinkCreate", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkDestroy", ("hipLinkDestroy", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuModuleGetFunction", ("hipModuleGetFunction", CONV_MODULE, API_DRIVER)), + ("cuModuleGetGlobal_v2", ("hipModuleGetGlobal", CONV_MODULE, API_DRIVER)), + ( + "cuModuleGetSurfRef", + ("hipModuleGetSurfRef", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleGetTexRef", ("hipModuleGetTexRef", CONV_MODULE, API_DRIVER)), + ("cuModuleLoad", ("hipModuleLoad", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadData", ("hipModuleLoadData", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadDataEx", ("hipModuleLoadDataEx", CONV_MODULE, API_DRIVER)), + ( + "cuModuleLoadFatBinary", + ("hipModuleLoadFatBinary", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleUnload", ("hipModuleUnload", CONV_MODULE, API_DRIVER)), + ( + "CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("CU_EVENT_DEFAULT", ("hipEventDefault", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_BLOCKING_SYNC", ("hipEventBlockingSync", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_DISABLE_TIMING", ("hipEventDisableTiming", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_INTERPROCESS", ("hipEventInterprocess", CONV_EVENT, API_DRIVER)), + ("cuEventCreate", ("hipEventCreate", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy_v2", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_DRIVER)), + ("cuEventQuery", ("hipEventQuery", CONV_EVENT, API_DRIVER)), + ("cuEventRecord", ("hipEventRecord", CONV_EVENT, API_DRIVER)), + ("cuEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_DRIVER)), + ("cuFuncSetAttribute", ("hipFuncSetAttribute", CONV_EVENT, API_DRIVER)), + ( + "cuFuncGetAttribute", + ("hipFuncGetAttribute", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunchKernel", ("hipModuleLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetBlockShape", + ("hipFuncSetBlockShape", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaLaunchKernel", ("hipLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedSize", + ("hipFuncSetSharedSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunch", ("hipLaunch", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLaunchGrid", ("hipLaunchGrid", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLaunchGridAsync", + ("hipLaunchGridAsync", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetf", ("hipParamSetf", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuParamSeti", ("hipParamSeti", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetv", ("hipParamSetv", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_DRIVER, + ), + ), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuOccupancyMaxPotentialBlockSize", + ("hipModuleOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_DRIVER), + ), + ( + "cuOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipModuleOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_DRIVER)), + ( + "cuStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreate", + ("hipStreamCreate__", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamDestroy_v2", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_DRIVER)), + ( + "cuStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamQuery", ("hipStreamQuery", CONV_STREAM, API_DRIVER)), + ("cuStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_DRIVER)), + ("cuStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_DRIVER)), + ( + "cuStreamWaitValue32", + ("hipStreamWaitValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamWriteValue32", + ("hipStreamWriteValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamBatchMemOp", + ("hipStreamBatchMemOp", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArray3DCreate", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ( + "cuArray3DGetDescriptor", + ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArrayCreate", ("hipArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuArrayDestroy", ("hipArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuArrayGetDescriptor", + ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcCloseMemHandle", + ("hipIpcCloseMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetEventHandle", + ("hipIpcGetEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetMemHandle", + ("hipIpcGetMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenEventHandle", + ("hipIpcOpenEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenMemHandle", + ("hipIpcOpenMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAlloc_v2", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost", ("hipMemAllocHost", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemAllocManaged", + ("hipMemAllocManaged", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemAllocPitch", + ("hipMemAllocPitch__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy", ("hipMemcpy__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpy2D", ("hipMemcpy2D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy2DAsync", + ("hipMemcpy2DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy2DUnaligned", + ("hipMemcpy2DUnaligned", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy3D", ("hipMemcpy3D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy3DAsync", + ("hipMemcpy3DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeer", + ("hipMemcpy3DPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyAsync", ("hipMemcpyAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoA", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoD", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoH", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyAtoHAsync", + ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyDtoA", ("hipMemcpyDtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyDtoD_v2", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync_v2", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH_v2", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync_v2", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyHtoAAsync", + ("hipMemcpyHtoAAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyHtoD_v2", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync_v2", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ( + "cuMemcpyPeerAsync", + ("hipMemcpyPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyPeer", ("hipMemcpyPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemFree", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFree_v2", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFreeHost", ("hipHostFree", CONV_MEM, API_DRIVER)), + ( + "cuMemGetAddressRange", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemGetInfo_v2", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostAlloc", ("hipHostMalloc", CONV_MEM, API_DRIVER)), + ( + "cuMemHostGetDevicePointer", + ("hipMemHostGetDevicePointer", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemHostGetFlags", + ("hipMemHostGetFlags", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemHostRegister_v2", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemHostUnregister", ("hipHostUnregister", CONV_MEM, API_DRIVER)), + ("cuMemsetD16_v2", ("hipMemsetD16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD16Async", + ("hipMemsetD16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D16_v2", ("hipMemsetD2D16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D16Async", + ("hipMemsetD2D16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D32_v2", ("hipMemsetD2D32", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D32Async", + ("hipMemsetD2D32Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D8_v2", ("hipMemsetD2D8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D8Async", + ("hipMemsetD2D8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD32_v2", ("hipMemset", CONV_MEM, API_DRIVER)), + ("cuMemsetD32Async", ("hipMemsetAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD8_v2", ("hipMemsetD8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD8Async", + ("hipMemsetD8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayCreate", + ("hipMipmappedArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayDestroy", + ("hipMipmappedArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayGetLevel", + ("hipMipmappedArrayGetLevel", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemPrefetchAsync", + ("hipMemPrefetchAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAdvise", ("hipMemAdvise", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerGetAttribute", + ("hipPointerGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemGetAddressRange_v2", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER), + ), + ("cuArray3DCreate_v2", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ("cuArray3DGetDescriptor_v2", ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER)), + ("cuArrayGetDescriptor_v2", ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER)), + ("cuMemAlloc", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost_v2", ("hipMemAllocHost", CONV_MEM, API_DRIVER)), + ("cuMemAllocPitch_v2", ("hipMemAllocPitch", CONV_MEM, API_DRIVER)), + ("cuMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostGetDevicePointer_v2", ("hipHostGetDevicePointer", CONV_MEM, API_DRIVER)), + ("cuMemHostRegister", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemcpy2DAsync_v2", ("hipMemcpyParam2DAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpy2DUnaligned_v2", ("hipDrvMemcpy2DUnaligned", CONV_MEM, API_DRIVER)), + ("cuMemcpy2D_v2", ("hipMemcpyParam2D", CONV_MEM, API_DRIVER)), + ("cuMemcpy3DAsync_v2", ("hipDrvMemcpy3DAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpy3D_v2", ("hipDrvMemcpy3D", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoA_v2", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoD_v2", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoHAsync_v2", ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoH_v2", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoA_v2", ("hipMemcpyDtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoD", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA_v2", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoD", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD16", ("hipMemsetD16", CONV_MEM, API_DRIVER)), + ("cuMemsetD32", ("hipMemsetD32", CONV_MEM, API_DRIVER)), + ("cuMemsetD8", ("hipMemsetD8", CONV_MEM, API_DRIVER)), + ("cuMemAddressFree", ("hipMemAddressFree", CONV_MEM, API_DRIVER)), + ("cuMemAddressReserve", ("hipMemAddressReserve", CONV_MEM, API_DRIVER)), + ("cuMemCreate", ("hipMemCreate", CONV_MEM, API_DRIVER)), + ("cuMemExportToShareableHandle", ("hipMemExportToShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemGetAccess", ("hipMemGetAccess", CONV_MEM, API_DRIVER)), + ("cuMemGetAllocationGranularity", ("hipMemGetAllocationGranularity", CONV_MEM, API_DRIVER)), + ("cuMemGetAllocationPropertiesFromHandle", ("hipMemGetAllocationPropertiesFromHandle", CONV_MEM, API_DRIVER)), + ("cuMemImportFromShareableHandle", ("hipMemImportFromShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemMap", ("hipMemMap", CONV_MEM, API_DRIVER)), + ("cuMemMapArrayAsync", ("hipMemMapArrayAsync", CONV_MEM, API_DRIVER)), + ("cuMemRelease", ("hipMemRelease", CONV_MEM, API_DRIVER)), + ("cuMemRetainAllocationHandle", ("hipMemRetainAllocationHandle", CONV_MEM, API_DRIVER)), + ("cuMemSetAccess", ("hipMemSetAccess", CONV_MEM, API_DRIVER)), + ("cuMemUnmap", ("hipMemUnmap", CONV_MEM, API_DRIVER)), + ("cuMemAllocAsync", ("hipMallocAsync", CONV_MEM, API_DRIVER)), + ("cuMemAllocFromPoolAsync", ("hipMallocFromPoolAsync", CONV_MEM, API_DRIVER)), + ("cuMemFreeAsync", ("hipFreeAsync", CONV_MEM, API_DRIVER)), + ("cuMemPoolCreate", ("hipMemPoolCreate", CONV_MEM, API_DRIVER)), + ("cuMemPoolDestroy", ("hipMemPoolDestroy", CONV_MEM, API_DRIVER)), + ("cuMemPoolExportPointer", ("hipMemPoolExportPointer", CONV_MEM, API_DRIVER)), + ("cuMemPoolExportToShareableHandle", ("hipMemPoolExportToShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemPoolGetAccess", ("hipMemPoolGetAccess", CONV_MEM, API_DRIVER)), + ("cuMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_DRIVER)), + ("cuMemPoolImportFromShareableHandle", ("hipMemPoolImportFromShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemPoolImportPointer", ("hipMemPoolImportPointer", CONV_MEM, API_DRIVER)), + ("cuMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_DRIVER)), + ("cuMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_DRIVER)), + ("cuMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_DRIVER)), + ( + "cuPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerSetAttribute", + ("hipPointerSetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TR_FILTER_MODE_POINT", ("hipFilterModePoint", CONV_TEX, API_DRIVER)), + ( + "CU_TR_FILTER_MODE_LINEAR", + ("hipFilterModeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddress", + ("hipTexRefGetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddressMode", + ("hipTexRefGetAddressMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetArray", + ("hipTexRefGetArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetBorderColor", + ("hipTexRefGetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFilterMode", + ("hipTexRefGetFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFlags", + ("hipTexRefGetFlags", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFormat", + ("hipTexRefGetFormat", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMaxAnisotropy", + ("hipTexRefGetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapFilterMode", + ("hipTexRefGetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelBias", + ("hipTexRefGetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelClamp", + ("hipTexRefGetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmappedArray", + ("hipTexRefGetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress", + ("hipTexRefSetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress2D", + ("hipTexRefSetAddress2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetAddressMode", ("hipTexRefSetAddressMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetArray", ("hipTexRefSetArray", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetBorderColor", + ("hipTexRefSetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetFilterMode", ("hipTexRefSetFilterMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFlags", ("hipTexRefSetFlags", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFormat", ("hipTexRefSetFormat", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetMaxAnisotropy", + ("hipTexRefSetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapFilterMode", + ("hipTexRefSetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelBias", + ("hipTexRefSetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelClamp", + ("hipTexRefSetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmappedArray", + ("hipTexRefSetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefCreate", ("hipTexRefCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuTexRefDestroy", + ("hipTexRefDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefGetArray", + ("hipSurfRefGetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefSetArray", + ("hipSurfRefSetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectCreate", + ("hipTexObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectDestroy", + ("hipTexObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceDesc", + ("hipTexObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceViewDesc", + ("hipTexObjectGetResourceViewDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetTextureDesc", + ("hipTexObjectGetTextureDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectCreate", + ("hipSurfObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectDestroy", + ("hipSurfObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectGetResourceDesc", + ("hipSurfObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuProfilerStart", ("hipProfilerStart", CONV_OTHER, API_DRIVER)), + ("cuProfilerStop", ("hipProfilerStop", CONV_OTHER, API_DRIVER)), + ( + "CU_GL_DEVICE_LIST_ALL", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_CURRENT_FRAME", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_NEXT_FRAME", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuGLGetDevices", ("hipGLGetDevices", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuWGLGetDevice", ("hipWGLGetDevice", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CU_GL_MAP_RESOURCE_FLAGS_NONE", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuGLCtxCreate", ("hipGLCtxCreate", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("cuGLInit", ("hipGLInit", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGLMapBufferObject", + ("hipGLMapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_ALL", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_DEVICE_LIST_NEXT_FRAME", + ("HIP_D3D9_DEVICE_LIST_NEXT_FRAME", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreate", + ("hipD3D9CtxCreate", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreateOnDevice", + ("hipD3D9CtxCreateOnDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D9RegisterResource", + ("hipGraphicsD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_NONE", + ("HIP_D3D9_MAPRESOURCE_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_REGISTER_FLAGS_NONE", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_REGISTER_FLAGS_ARRAY", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPointer", + ("hipD3D9ResourceGetMappedPointer", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_ALL", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_NONE", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_REGISTER_FLAGS_NONE", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_REGISTER_FLAGS_ARRAY", + ("HIP_D3D10_REGISTER_FLAGS_ARRAY", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreate", + ("hipD3D10CtxCreate", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreateOnDevice", + ("hipD3D10CtxCreateOnDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedArray", + ("hipD3D10ResourceGetMappedArray", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPitch", + ("hipD3D10ResourceGetMappedPitch", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD310ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_ALL", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D11_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11CtxCreate", + ("hipD3D11CtxCreate", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11CtxCreateOnDevice", + ("hipD3D11CtxCreateOnDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDirect3DDevice", + ("hipD3D11GetDirect3DDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuVDPAUCtxCreate", + ("hipVDPAUCtxCreate", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerAcquireFrame", + ("hipEGLStreamConsumerAcquireFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuEGLStreamConsumerDisconnect", + ("hipEGLStreamConsumerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerReleaseFrame", + ("hipEGLStreamConsumerReleaseFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerPresentFrame", + ("hipEGLStreamProducerPresentFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32F", ("HIP_R_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64F", ("HIP_R_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16F", ("HIP_R_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8I", ("HIP_R_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32F", ("HIP_C_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64F", ("HIP_C_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16F", ("HIP_C_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8I", ("HIP_C_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8U", ("HIP_R_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8U", ("HIP_C_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32I", ("HIP_R_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32I", ("HIP_C_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32U", ("HIP_R_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32U", ("HIP_C_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4I", ("HIP_R_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4I", ("HIP_C_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4U", ("HIP_R_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4U", ("HIP_C_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16I", ("HIP_R_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16I", ("HIP_C_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16U", ("HIP_R_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16U", ("HIP_C_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64I", ("HIP_R_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64I", ("HIP_C_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64U", ("HIP_R_64U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4F_E2M1", ("HIP_R_4F_E2M1", CONV_TYPE, API_RUNTIME)), + ( + "MAJOR_VERSION", + ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "MINOR_VERSION", + ("hipLibraryMinorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "PATCH_LEVEL", + ("hipLibraryPatchVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachGlobal", + ("hipMemAttachGlobal", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachHost", + ("hipMemAttachHost", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachSingle", + ("hipMemAttachSingle", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDefault", + ("hipOccupancyDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDisableCachingOverride", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaGetLastError", ("hipGetLastError", CONV_ERROR, API_RUNTIME)), + ("cudaPeekAtLastError", ("hipPeekAtLastError", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorName", ("hipGetErrorName", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorString", ("hipGetErrorString", CONV_ERROR, API_RUNTIME)), + ("cudaMemcpy3DParms", ("hipMemcpy3DParms", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DPeerParms", + ("hipMemcpy3DPeerParms", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy", ("hipMemcpy", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToArray", ("hipMemcpyToArray", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbol", ("hipMemcpyToSymbol", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbolAsync", ("hipMemcpyToSymbolAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyAsync", ("hipMemcpyAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2D", ("hipMemcpy2D", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DAsync", ("hipMemcpy2DAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DToArray", ("hipMemcpy2DToArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy2DArrayToArray", + ("hipMemcpy2DArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArray", + ("hipMemcpy2DFromArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArrayAsync", + ("hipMemcpy2DFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DToArrayAsync", + ("hipMemcpy2DToArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy3D", ("hipMemcpy3D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DAsync", + ("hipMemcpy3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeer", + ("hipMemcpy3DPeer", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyArrayToArray", + ("hipMemcpyArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyFromArrayAsync", + ("hipMemcpyFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyFromSymbol", ("hipMemcpyFromSymbol", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyFromSymbolAsync", + ("hipMemcpyFromSymbolAsync", CONV_MEM, API_RUNTIME), + ), + ("cudaMemAdvise", ("hipMemAdvise", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetReadMostly", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetReadMostly", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetPreferredLocation", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseUnsetPreferredLocation", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseSetAccessedBy", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetAccessedBy", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeReadMostly", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributePreferredLocation", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemRangeAttributeAccessedBy", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeLastPrefetchLocation", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaMemcpyHostToHost", ("hipMemcpyHostToHost", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyHostToDevice", ("hipMemcpyHostToDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyDeviceToHost", ("hipMemcpyDeviceToHost", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyDeviceToDevice", + ("hipMemcpyDeviceToDevice", CONV_MEM, API_RUNTIME), + ), + ("cudaMemcpyDefault", ("hipMemcpyDefault", CONV_MEM, API_RUNTIME)), + ("cudaMemset", ("hipMemset", CONV_MEM, API_RUNTIME)), + ("cudaMemsetAsync", ("hipMemsetAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemset2D", ("hipMemset2D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemset2DAsync", + ("hipMemset2DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemset3D", ("hipMemset3D", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemset3DAsync", + ("hipMemset3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_RUNTIME)), + ("cudaDeviceGetDefaultMemPool", ("hipDeviceGetDefaultMemPool", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessDesc", ("hipMemAccessDesc", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessFlagsProtReadWrite", ("hipMemAccessFlagsProtReadWrite", CONV_MEM, API_RUNTIME)), + ("cudaMemLocationTypeDevice", ("hipMemLocationTypeDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReleaseThreshold", ("hipMemPoolAttrReleaseThreshold", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemCurrent", ("hipMemPoolAttrReservedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemHigh", ("hipMemPoolAttrReservedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)), + ( + "cudaMemPoolReuseAllowInternalDependencies", + ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME) + ), + ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)), + ( + "cudaMemPoolReuseFollowEventDependencies", + ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME) + ), + ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)), + ("cudaMemPool_t", ("hipMemPool_t", CONV_MEM, API_RUNTIME)), + ( + "cudaArrayGetInfo", + ("hipArrayGetInfo", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFreeMipmappedArray", + ("hipFreeMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetMipmappedArrayLevel", + ("hipGetMipmappedArrayLevel", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolAddress", + ("hipGetSymbolAddress", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolSize", + ("hipGetSymbolSize", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemPrefetchAsync", + ("hipMemPrefetchAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocHost", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMallocArray", ("hipMallocArray", CONV_MEM, API_RUNTIME)), + ("cudaMallocAsync", ("hipMallocAsync", CONV_MEM, API_RUNTIME)), + ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3D", ("hipMalloc3D", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3DArray", ("hipMalloc3DArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMallocManaged", + ("hipMallocManaged", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMallocMipmappedArray", + ("hipMallocMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocPitch", ("hipMallocPitch", CONV_MEM, API_RUNTIME)), + ("cudaFreeHost", ("hipHostFree", CONV_MEM, API_RUNTIME)), + ("cudaFreeArray", ("hipFreeArray", CONV_MEM, API_RUNTIME)), + ("cudaFreeAsync", ("hipFreeAsync", CONV_MEM, API_RUNTIME)), + ("cudaFree", ("hipFree", CONV_MEM, API_RUNTIME)), + ("cudaHostRegister", ("hipHostRegister", CONV_MEM, API_RUNTIME)), + ("cudaHostUnregister", ("hipHostUnregister", CONV_MEM, API_RUNTIME)), + ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeUnregistered", ("hipMemoryTypeUnregistered", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeManaged", ("hipMemoryTypeManaged", CONV_MEM, API_RUNTIME)), + ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), + ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), + ("cudaHostNodeParams", ("hipHostNodeParams", CONV_MEM, API_RUNTIME)), + ( + "cudaHostAllocWriteCombined", + ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), + ), + ("cudaHostGetFlags", ("hipHostGetFlags", CONV_MEM, API_RUNTIME)), + ("cudaHostRegisterDefault", ("hipHostRegisterDefault", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterPortable", + ("hipHostRegisterPortable", CONV_MEM, API_RUNTIME), + ), + ("cudaHostRegisterMapped", ("hipHostRegisterMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterIoMemory", + ("hipHostRegisterIoMemory", CONV_MEM, API_RUNTIME), + ), + # ("warpSize", ("hipWarpSize", CONV_SPECIAL_FUNC, API_RUNTIME), (HIP actually uses warpSize...)), + ("cudaEventCreate", ("hipEventCreate", CONV_EVENT, API_RUNTIME)), + ( + "cudaEventCreateWithFlags", + ("hipEventCreateWithFlags", CONV_EVENT, API_RUNTIME), + ), + ("cudaEventDestroy", ("hipEventDestroy", CONV_EVENT, API_RUNTIME)), + ("cudaEventRecord", ("hipEventRecord", CONV_EVENT, API_RUNTIME)), + ("cudaEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_RUNTIME)), + ("cudaEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_RUNTIME)), + ("cudaEventQuery", ("hipEventQuery", CONV_EVENT, API_RUNTIME)), + ("cudaEventDefault", ("hipEventDefault", CONV_EVENT, API_RUNTIME)), + ("cudaEventBlockingSync", ("hipEventBlockingSync", CONV_EVENT, API_RUNTIME)), + ("cudaEventDisableTiming", ("hipEventDisableTiming", CONV_EVENT, API_RUNTIME)), + ("cudaEventInterprocess", ("hipEventInterprocess", CONV_EVENT, API_RUNTIME)), + ("cudaStreamCreate", ("hipStreamCreate", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamCreateWithFlags", + ("hipStreamCreateWithFlags", CONV_STREAM, API_RUNTIME), + ), + ( + "cudaStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_RUNTIME)), + ("cudaStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_RUNTIME)), + ("cudaStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_RUNTIME)), + ("cudaStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_RUNTIME)), + ("cudaStreamQuery", ("hipStreamQuery", CONV_STREAM, API_RUNTIME)), + ("cudaStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCpuDeviceId", ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME)), + ("cudaStreamDefault", ("hipStreamDefault", CONV_TYPE, API_RUNTIME)), + ("cudaStreamNonBlocking", ("hipStreamNonBlocking", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo", ("hipStreamGetCaptureInfo", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo_v2", ("hipStreamGetCaptureInfo_v2", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatus", ("hipStreamCaptureStatus", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatusActive", ("hipStreamCaptureStatusActive", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatusNone", ("hipStreamCaptureStatusNone", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)), + ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), + ("cudaStreamSetCaptureDependencies", ("hipStreamSetCaptureDependencies", CONV_STREAM, API_RUNTIME)), + ("cudaStreamUpdateCaptureDependencies", ("hipStreamUpdateCaptureDependencies", CONV_STREAM, API_RUNTIME)), + ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateFlagAutoFreeOnLaunch", + ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphDestroy", ("hipGraphDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecDestroy", ("hipGraphExecDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphGetNodes", ("hipGraphGetNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotPrint", ("hipGraphDebugDotPrint", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), + ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), + ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceGetGraphMemAttribute", ("hipDeviceGetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceGraphMemTrim", ("hipDeviceGraphMemTrim", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSetGraphMemAttribute", ("hipDeviceSetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddChildGraphNode", ("hipGraphAddChildGraphNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddDependencies", ("hipGraphAddDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEmptyNode", ("hipGraphAddEmptyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEventRecordNode", ("hipGraphAddEventRecordNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEventWaitNode", ("hipGraphAddEventWaitNode", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphAddExternalSemaphoresSignalNode", + ("hipGraphAddExternalSemaphoresSignalNode", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphAddExternalSemaphoresWaitNode", ("hipGraphAddExternalSemaphoresWaitNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddHostNode", ("hipGraphAddHostNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddKernelNode", ("hipGraphAddKernelNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemAllocNode", ("hipGraphAddMemAllocNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemFreeNode", ("hipGraphAddMemFreeNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNode", ("hipGraphAddMemcpyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNode1D", ("hipGraphAddMemcpyNode1D", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNodeFromSymbol", ("hipGraphAddMemcpyNodeFromSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNodeToSymbol", ("hipGraphAddMemcpyNodeToSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemsetNode", ("hipGraphAddMemsetNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddNode", ("hipGraphAddNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphChildGraphNodeGetGraph", ("hipGraphChildGraphNodeGetGraph", CONV_TYPE, API_RUNTIME)), + ("cudaGraphClone", ("hipGraphClone", CONV_TYPE, API_RUNTIME)), + ("cudaGraphCreate", ("hipGraphCreate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDestroyNode", ("hipGraphDestroyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventRecordNodeGetEvent", ("hipGraphEventRecordNodeGetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventRecordNodeSetEvent", ("hipGraphEventRecordNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventWaitNodeGetEvent", ("hipGraphEventWaitNodeGetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventWaitNodeSetEvent", ("hipGraphEventWaitNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecChildGraphNodeSetParams", ("hipGraphExecChildGraphNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecEventRecordNodeSetEvent", ("hipGraphExecEventRecordNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecEventWaitNodeSetEvent", ("hipGraphExecEventWaitNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecExternalSemaphoresSignalNodeSetParams", + ("hipGraphExecExternalSemaphoresSignalNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecExternalSemaphoresWaitNodeSetParams", + ("hipGraphExecExternalSemaphoresWaitNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecGetFlags", ("hipGraphExecGetFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecHostNodeSetParams", ("hipGraphExecHostNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecKernelNodeSetParams", ("hipGraphExecKernelNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecMemcpyNodeSetParams", ("hipGraphExecMemcpyNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecMemcpyNodeSetParams1D", ("hipGraphExecMemcpyNodeSetParams1D", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecMemcpyNodeSetParamsFromSymbol", + ("hipGraphExecMemcpyNodeSetParamsFromSymbol", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecMemcpyNodeSetParamsToSymbol", + ("hipGraphExecMemcpyNodeSetParamsToSymbol", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecMemsetNodeSetParams", ("hipGraphExecMemsetNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecNodeSetParams", ("hipGraphExecNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdate", ("hipGraphExecUpdate", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExternalSemaphoresSignalNodeGetParams", + ("hipGraphExternalSemaphoresSignalNodeGetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresSignalNodeSetParams", + ("hipGraphExternalSemaphoresSignalNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresWaitNodeGetParams", + ("hipGraphExternalSemaphoresWaitNodeGetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresWaitNodeSetParams", + ("hipGraphExternalSemaphoresWaitNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphGetEdges", ("hipGraphGetEdges", CONV_TYPE, API_RUNTIME)), + ("cudaGraphGetRootNodes", ("hipGraphGetRootNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphHostNodeGetParams", ("hipGraphHostNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphHostNodeSetParams", ("hipGraphHostNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithParams", ("hipGraphInstantiateWithParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeCopyAttributes", ("hipGraphKernelNodeCopyAttributes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeGetAttribute", ("hipGraphKernelNodeGetAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeGetParams", ("hipGraphKernelNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeSetAttribute", ("hipGraphKernelNodeSetAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeSetParams", ("hipGraphKernelNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAllocNodeGetParams", ("hipGraphMemAllocNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemFreeNodeGetParams", ("hipGraphMemFreeNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeGetParams", ("hipGraphMemcpyNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParams", ("hipGraphMemcpyNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParams1D", ("hipGraphMemcpyNodeSetParams1D", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParamsFromSymbol", ("hipGraphMemcpyNodeSetParamsFromSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParamsToSymbol", ("hipGraphMemcpyNodeSetParamsToSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemsetNodeGetParams", ("hipGraphMemsetNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemsetNodeSetParams", ("hipGraphMemsetNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeFindInClone", ("hipGraphNodeFindInClone", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetDependencies", ("hipGraphNodeGetDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetDependentNodes", ("hipGraphNodeGetDependentNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetEnabled", ("hipGraphNodeGetEnabled", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetType", ("hipGraphNodeGetType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeSetEnabled", ("hipGraphNodeSetEnabled", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeSetParams", ("hipGraphNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphReleaseUserObject", ("hipGraphReleaseUserObject", CONV_TYPE, API_RUNTIME)), + ("cudaGraphRemoveDependencies", ("hipGraphRemoveDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphUpload", ("hipGraphUpload", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectRelease", ("hipUserObjectRelease", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectRetain", ("hipUserObjectRetain", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlags", ("hipGraphDebugDotFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsEventNodeParams", ("hipGraphDebugDotFlagsEventNodeParams", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphDebugDotFlagsExtSemasSignalNodeParams", + ("hipGraphDebugDotFlagsExtSemasSignalNodeParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphDebugDotFlagsExtSemasWaitNodeParams", + ("hipGraphDebugDotFlagsExtSemasWaitNodeParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphDebugDotFlagsHandles", ("hipGraphDebugDotFlagsHandles", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsHostNodeParams", ("hipGraphDebugDotFlagsHostNodeParams", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphDebugDotFlagsKernelNodeAttributes", + ("hipGraphDebugDotFlagsKernelNodeAttributes", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphDebugDotFlagsKernelNodeParams", ("hipGraphDebugDotFlagsKernelNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsMemcpyNodeParams", ("hipGraphDebugDotFlagsMemcpyNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsMemsetNodeParams", ("hipGraphDebugDotFlagsMemsetNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyType", ("hipGraphDependencyType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyTypeDefault", ("hipGraphDependencyTypeDefault", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyTypeProgrammatic", ("hipGraphDependencyTypeProgrammatic", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyType_enum", ("hipGraphDependencyType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEdgeData", ("hipGraphEdgeData", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEdgeData_st", ("hipGraphEdgeData", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdateError", ("hipGraphExecUpdateError", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecUpdateErrorFunctionChanged", + ("hipGraphExecUpdateErrorFunctionChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorNodeTypeChanged", + ("hipGraphExecUpdateErrorNodeTypeChanged", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecUpdateErrorNotSupported", ("hipGraphExecUpdateErrorNotSupported", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecUpdateErrorParametersChanged", + ("hipGraphExecUpdateErrorParametersChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorTopologyChanged", + ("hipGraphExecUpdateErrorTopologyChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorUnsupportedFunctionChange", + ("hipGraphExecUpdateErrorUnsupportedFunctionChange", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecUpdateResult", ("hipGraphExecUpdateResult", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdateSuccess", ("hipGraphExecUpdateSuccess", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateError", ("hipGraphInstantiateError", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagDeviceLaunch", ("hipGraphInstantiateFlagDeviceLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagUpload", ("hipGraphInstantiateFlagUpload", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateFlagUseNodePriority", + ("hipGraphInstantiateFlagUseNodePriority", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphInstantiateFlags", ("hipGraphInstantiateFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateInvalidStructure", ("hipGraphInstantiateInvalidStructure", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateMultipleDevicesNotSupported", + ("hipGraphInstantiateMultipleDevicesNotSupported", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphInstantiateNodeOperationNotSupported", + ("hipGraphInstantiateNodeOperationNotSupported", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphInstantiateParams", ("hipGraphInstantiateParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateParams_st", ("hipGraphInstantiateParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateResult", ("hipGraphInstantiateResult", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateSuccess", ("hipGraphInstantiateSuccess", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodePortDefault", ("hipGraphKernelNodePortDefault", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphKernelNodePortLaunchCompletion", + ("hipGraphKernelNodePortLaunchCompletion", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphKernelNodePortProgrammatic", ("hipGraphKernelNodePortProgrammatic", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrReservedMemCurrent", ("hipGraphMemAttrReservedMemCurrent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrReservedMemHigh", ("hipGraphMemAttrReservedMemHigh", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrUsedMemCurrent", ("hipGraphMemAttrUsedMemCurrent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrUsedMemHigh", ("hipGraphMemAttrUsedMemHigh", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttributeType", ("hipGraphMemAttributeType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeParams", ("hipGraphNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeType", ("hipGraphNodeType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeConditional", ("hipGraphNodeTypeConditional", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeCount", ("hipGraphNodeTypeCount", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeEmpty", ("hipGraphNodeTypeEmpty", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeEventRecord", ("hipGraphNodeTypeEventRecord", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeExtSemaphoreSignal", ("hipGraphNodeTypeExtSemaphoreSignal", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeExtSemaphoreWait", ("hipGraphNodeTypeExtSemaphoreWait", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeGraph", ("hipGraphNodeTypeGraph", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeHost", ("hipGraphNodeTypeHost", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeKernel", ("hipGraphNodeTypeKernel", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemAlloc", ("hipGraphNodeTypeMemAlloc", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemFree", ("hipGraphNodeTypeMemFree", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemcpy", ("hipGraphNodeTypeMemcpy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemset", ("hipGraphNodeTypeMemset", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeWaitEvent", ("hipGraphNodeTypeWaitEvent", CONV_TYPE, API_RUNTIME)), + ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectCreate", ("hipUserObjectCreate", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectNoDestructorSync", ("hipUserObjectNoDestructorSync", CONV_TYPE, API_RUNTIME)), + ("cudaThreadExchangeStreamCaptureMode", ("hipThreadExchangeStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamIsCapturing", ("hipStreamIsCapturing", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSynchronize", ("hipDeviceSynchronize", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceReset", ("hipDeviceReset", CONV_DEVICE, API_RUNTIME)), + ("cudaSetDevice", ("hipSetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDevice", ("hipGetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDeviceCount", ("hipGetDeviceCount", CONV_DEVICE, API_RUNTIME)), + ("cudaChooseDevice", ("hipChooseDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaThreadExit", ("hipDeviceReset", CONV_THREAD, API_RUNTIME)), + ( + "cudaThreadGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadGetLimit", + ("hipThreadGetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaThreadSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadSetLimit", + ("hipThreadSetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaThreadSynchronize", ("hipDeviceSynchronize", CONV_THREAD, API_RUNTIME)), + ("cudaDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDevAttrMaxThreadsPerBlock", + ("hipDeviceAttributeMaxThreadsPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimX", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimY", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimZ", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimX", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimY", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimZ", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlock", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlockOptin", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTotalConstantMemory", + ("hipDeviceAttributeTotalConstantMemory", CONV_TYPE, API_RUNTIME), + ), + ("cudaDevAttrWarpSize", ("hipDeviceAttributeWarpSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrMaxPitch", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMaxRegistersPerBlock", + ("hipDeviceAttributeMaxRegistersPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrClockRate", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTextureAlignment", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGpuOverlap", + ("hipDeviceAttributeGpuOverlap", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMultiProcessorCount", + ("hipDeviceAttributeMultiprocessorCount", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrKernelExecTimeout", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIntegrated", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrCanMapHostMemory", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeMode", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DWidth", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DWidth", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DHeight", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidth", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeight", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepth", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredHeight", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSurfaceAlignment", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentKernels", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrEccEnabled", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDevAttrPciBusId", ("hipDeviceAttributePciBusId", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrPciDeviceId", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTccDriver", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMemoryClockRate", + ("hipDeviceAttributeMemoryClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrGlobalMemoryBusWidth", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrL2CacheSize", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxThreadsPerMultiProcessor", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrAsyncEngineCount", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrUnifiedAddressing", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherWidth", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherHeight", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidthAlt", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeightAlt", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepthAlt", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPciDomainId", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrTexturePitchAlignment", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapWidth", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DWidth", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DWidth", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DHeight", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DWidth", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DHeight", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DDepth", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredHeight", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLinearWidth", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearWidth", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearHeight", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearPitch", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedHeight", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeCapabilityMajor", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrComputeCapabilityMinor", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrStreamPrioritiesSupported", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGlobalL1CacheSupported", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrLocalL1CacheSupported", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSharedMemoryPerMultiprocessor", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + ), + ), + ( + "cudaDevAttrMaxRegistersPerMultiprocessor", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrManagedMemory", + ( + "hipDeviceAttributeManagedMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIsMultiGpuBoard", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMultiGpuBoardGroupID", + ( + "hipDeviceAttributeMultiGpuBoardGroupID", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrHostNativeAtomicSupported", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSingleToDoublePrecisionPerfRatio", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPageableMemoryAccess", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentManagedAccess", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputePreemptionSupported", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrCanUseHostPointerForRegisteredMem", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_RUNTIME), + ), + ( + "cudaHostGetDevicePointer", + ("hipHostGetDevicePointer", CONV_MEM, API_RUNTIME), + ), + ( + "cudaGetDeviceProperties", + ("hipGetDeviceProperties", CONV_DEVICE, API_RUNTIME), + ), + ("cudaDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDeviceGetByPCIBusId", + ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetStreamPriorityRange", + ( + "hipDeviceGetStreamPriorityRange", + CONV_DEVICE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaSetValidDevices", + ("hipSetValidDevices", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevP2PAttrPerformanceRank", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrAccessSupported", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrNativeAtomicSupported", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeDefault", + ("hipComputeModeDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusive", + ("hipComputeModeExclusive", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeProhibited", + ("hipComputeModeProhibited", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusiveProcess", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetDeviceFlags", + ("hipGetDeviceFlags", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSetDeviceFlags", ("hipSetDeviceFlags", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceScheduleAuto", ("hipDeviceScheduleAuto", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleSpin", ("hipDeviceScheduleSpin", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleYield", ("hipDeviceScheduleYield", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleMask", + ("hipDeviceScheduleMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMapHost", ("hipDeviceMapHost", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceLmemResizeToMax", + ("hipDeviceLmemResizeToMax", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMask", ("hipDeviceMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaDeviceSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaDeviceGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributeMaxDynamicSharedMemorySize", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributePreferredSharedMemoryCarveout", + ("hipFuncAttributePreferredSharedMemoryCarveout", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncSetAttribute", + ("hipFuncSetAttribute", CONV_EXEC, API_RUNTIME), + ), + ("cudaFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferNone", + ("hipFuncCachePreferNone", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncCachePreferShared", + ("hipFuncCachePreferShared", CONV_CACHE, API_RUNTIME), + ), + ("cudaFuncCachePreferL1", ("hipFuncCachePreferL1", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferEqual", + ("hipFuncCachePreferEqual", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncGetAttributes", + ("hipFuncGetAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetParameterBuffer", + ("hipGetParameterBuffer", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForDevice", + ("hipSetDoubleForDevice", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForHost", + ("hipSetDoubleForHost", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaConfigureCall", + ("hipConfigureCall", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLaunch", ("hipLaunch", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaLaunchCooperativeKernel", + ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME), + ), + ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaSetupArgument", + ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_RUNTIME)), + ( + "cudaRuntimeGetVersion", + ("hipRuntimeGetVersion", CONV_VERSION, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyMaxPotentialBlockSize", + ("hipOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_RUNTIME), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_RUNTIME, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMem", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMem", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_RUNTIME)), + ( + "cudaDeviceDisablePeerAccess", + ("hipDeviceDisablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ( + "cudaDeviceEnablePeerAccess", + ("hipDeviceEnablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ("cudaMemcpyPeerAsync", ("hipMemcpyPeerAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyPeer", ("hipMemcpyPeer", CONV_MEM, API_RUNTIME)), + ( + "cudaIpcMemLazyEnablePeerAccess", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceSetSharedMemConfig", + ("hipDeviceSetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetSharedMemConfig", + ("hipDeviceGetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeDefault", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeFourByte", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeEightByte", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaLimitStackSize", + ("hipLimitStackSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitPrintfFifoSize", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLimitMallocHeapSize", ("hipLimitMallocHeapSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaLimitDevRuntimeSyncDepth", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitDevRuntimePendingLaunchCount", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceGetLimit", ("hipDeviceGetLimit", CONV_DEVICE, API_RUNTIME)), + ("cudaProfilerStart", ("hipProfilerStart", CONV_OTHER, API_RUNTIME)), + ("cudaProfilerStop", ("hipProfilerStop", CONV_OTHER, API_RUNTIME)), + ( + "cudaKeyValuePair", + ("hipKeyValuePair", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCSV", ("hipCSV", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaReadModeElementType", ("hipReadModeElementType", CONV_TEX, API_RUNTIME)), + ( + "cudaReadModeNormalizedFloat", + ("hipReadModeNormalizedFloat", CONV_TEX, API_RUNTIME), + ), + ("cudaFilterModePoint", ("hipFilterModePoint", CONV_TEX, API_RUNTIME)), + ("cudaFilterModeLinear", ("hipFilterModeLinear", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture", ("hipBindTexture", CONV_TEX, API_RUNTIME)), + ("cudaUnbindTexture", ("hipUnbindTexture", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture2D", ("hipBindTexture2D", CONV_TEX, API_RUNTIME)), + ("cudaBindTextureToArray", ("hipBindTextureToArray", CONV_TEX, API_RUNTIME)), + ( + "cudaBindTextureToMipmappedArray", + ("hipBindTextureToMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureAlignmentOffset", + ("hipGetTextureAlignmentOffset", CONV_TEX, API_RUNTIME), + ), + ("cudaGetTextureReference", ("hipGetTextureReference", CONV_TEX, API_RUNTIME)), + ( + "cudaChannelFormatKindSigned", + ("hipChannelFormatKindSigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindUnsigned", + ("hipChannelFormatKindUnsigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindFloat", + ("hipChannelFormatKindFloat", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindNone", + ("hipChannelFormatKindNone", CONV_TEX, API_RUNTIME), + ), + ("cudaCreateChannelDesc", ("hipCreateChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaGetChannelDesc", ("hipGetChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypeArray", ("hipResourceTypeArray", CONV_TEX, API_RUNTIME)), + ( + "cudaResourceTypeMipmappedArray", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ("cudaResourceTypeLinear", ("hipResourceTypeLinear", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypePitch2D", ("hipResourceTypePitch2D", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatNone", ("hipResViewFormatNone", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedChar1", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar2", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar4", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar1", + ("hipResViewFormatSignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar2", + ("hipResViewFormatSignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar4", + ("hipResViewFormatSignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort1", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort2", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort4", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort1", + ("hipResViewFormatSignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort2", + ("hipResViewFormatSignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort4", + ("hipResViewFormatSignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt1", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt2", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt4", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt1", + ("hipResViewFormatSignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt2", + ("hipResViewFormatSignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt4", + ("hipResViewFormatSignedInt4", CONV_TEX, API_RUNTIME), + ), + ("cudaResViewFormatHalf1", ("hipResViewFormatHalf1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf2", ("hipResViewFormatHalf2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf4", ("hipResViewFormatHalf4", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat1", ("hipResViewFormatFloat1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat2", ("hipResViewFormatFloat2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat4", ("hipResViewFormatFloat4", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedBlockCompressed1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_RUNTIME), + ), + ("cudaAddressModeWrap", ("hipAddressModeWrap", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeClamp", ("hipAddressModeClamp", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeMirror", ("hipAddressModeMirror", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeBorder", ("hipAddressModeBorder", CONV_TEX, API_RUNTIME)), + ("cudaCreateTextureObject", ("hipCreateTextureObject", CONV_TEX, API_RUNTIME)), + ( + "cudaDestroyTextureObject", + ("hipDestroyTextureObject", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceDesc", + ("hipGetTextureObjectResourceDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceViewDesc", + ("hipGetTextureObjectResourceViewDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectTextureDesc", + ("hipGetTextureObjectTextureDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaBindSurfaceToArray", + ("hipBindSurfaceToArray", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceReference", + ("hipGetSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeZero", + ("hipBoundaryModeZero", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeClamp", + ("hipBoundaryModeClamp", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeTrap", + ("hipBoundaryModeTrap", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeForced", + ("hipFormatModeForced", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeAuto", + ("hipFormatModeAuto", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaCreateSurfaceObject", + ("hipCreateSurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDestroySurfaceObject", + ("hipDestroySurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceObjectResourceDesc", + ( + "hipGetSurfaceObjectResourceDesc", + CONV_SURFACE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaIpcCloseMemHandle", ("hipIpcCloseMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetEventHandle", ("hipIpcGetEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetMemHandle", ("hipIpcGetMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenEventHandle", ("hipIpcOpenEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenMemHandle", ("hipIpcOpenMemHandle", CONV_DEVICE, API_RUNTIME)), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveX", + ( + "hipGraphicsCubeFacePositiveX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeX", + ( + "hipGraphicsCubeFaceNegativeX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveY", + ( + "hipGraphicsCubeFacePositiveY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeY", + ( + "hipGraphicsCubeFaceNegativeY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveZ", + ( + "hipGraphicsCubeFacePositiveZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeZ", + ( + "hipGraphicsCubeFaceNegativeZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsNone", + ("hipGraphicsMapFlagsNone", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlagsReadOnly", + ( + "hipGraphicsMapFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsWriteDiscard", + ( + "hipGraphicsMapFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsNone", + ( + "hipGraphicsRegisterFlagsNone", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsReadOnly", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsWriteDiscard", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsSurfaceLoadStore", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsTextureGather", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLDeviceListAll", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListCurrentFrame", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListNextFrame", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsNone", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsReadOnly", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapFlagsWriteDiscard", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapBufferObject", + ("hipGLMapBufferObject__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetGLDevice", + ("hipGLSetGLDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListAll", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListCurrentFrame", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9DeviceListNextFrame", + ( + "HIP_D3D9_DEVICE_LIST_NEXT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9SetDirect3DDevice", + ("hipD3D9SetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D9RegisterResource", + ( + "hipGraphicsD3D9RegisterResource", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlagsNone", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_NONE", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsReadOnly", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsWriteDiscard", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9RegisterFlagsNone", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlagsArray", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPointer", + ( + "hipD3D9ResourceGetMappedPointer", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListAll", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListCurrentFrame", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10DeviceListNextFrame", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsNone", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsReadOnly", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsWriteDiscard", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10RegisterFlagsNone", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlagsArray", + ( + "HIP_D3D10_REGISTER_FLAGS_ARRAY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetMappedArray", + ( + "hipD3D10ResourceGetMappedArray", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPitch", + ( + "hipD3D10ResourceGetMappedPitch", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10SetDirect3DDevice", + ("hipD3D10SetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListAll", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListCurrentFrame", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11DeviceListNextFrame", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaVDPAUSetVDPAUDevice", + ("hipVDPAUSetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerAcquireFrame", + ( + "hipEGLStreamConsumerAcquireFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerReleaseFrame", + ( + "hipEGLStreamConsumerReleaseFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerPresentFrame", + ( + "hipEGLStreamProducerPresentFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cublasInit", ("hipblasInit", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasShutdown", + ("hipblasShutdown", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVersion", + ("hipblasGetVersion", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetError", + ("hipblasGetError", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAlloc", ("hipblasAlloc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasFree", ("hipblasFree", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSetKernelStream", + ("hipblasSetKernelStream", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetAtomicsMode", + ("hipblasGetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetAtomicsMode", + ("hipblasSetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetMathMode", + ("hipblasGetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMathMode", + ("hipblasSetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_OP_N", ("HIPBLAS_OP_N", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_OP_T", + ("HIPBLAS_OP_T", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_OP_C", + ("HIPBLAS_OP_C", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_SUCCESS", + ("HIPBLAS_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_INITIALIZED", + ("HIPBLAS_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ALLOC_FAILED", + ("HIPBLAS_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INVALID_VALUE", + ("HIPBLAS_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_MAPPING_ERROR", + ("HIPBLAS_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_EXECUTION_FAILED", + ("HIPBLAS_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INTERNAL_ERROR", + ("HIPBLAS_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_SUPPORTED", + ("HIPBLAS_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ARCH_MISMATCH", + ("HIPBLAS_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPBLAS_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPBLAS_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_DIAG_NON_UNIT", + ("HIPBLAS_DIAG_NON_UNIT", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ("CUBLAS_DIAG_UNIT", ("HIPBLAS_DIAG_UNIT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_LEFT", ("HIPBLAS_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_RIGHT", ("HIPBLAS_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_POINTER_MODE_HOST", + ("HIPBLAS_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_POINTER_MODE_DEVICE", + ("HIPBLAS_POINTER_MODE_DEVICE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_ATOMICS_NOT_ALLOWED", + ( + "HIPBLAS_ATOMICS_NOT_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_ATOMICS_ALLOWED", + ( + "HIPBLAS_ATOMICS_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_FLOAT", + ( + "HIPBLAS_DATA_FLOAT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_DOUBLE", + ( + "HIPBLAS_DATA_DOUBLE", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_HALF", + ("HIPBLAS_DATA_HALF", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "CUBLAS_DATA_INT8", + ("HIPBLAS_DATA_INT8", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_GEMM_DEFAULT", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_GEMM_DEFAULT_TENSOR_OP", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("cublasCreate", ("hipblasCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy", ("hipblasDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetVector", ("hipblasSetVector", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetVector", ("hipblasGetVector", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSetVectorAsync", + ("hipblasSetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVectorAsync", + ("hipblasGetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetMatrix", ("hipblasSetMatrix", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetMatrix", ("hipblasGetMatrix", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetMatrixAsync", + ("hipblasGetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMatrixAsync", + ("hipblasSetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasXerbla", ("hipblasXerbla", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSnrm2", ("hipblasSnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2", ("hipblasDnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasScnrm2", ("hipblasScnrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDznrm2", ("hipblasDznrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasNrm2Ex", + ("hipblasNrm2Ex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdot", ("hipblasSdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSdotBatched", + ("hipblasSdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDdot", ("hipblasDdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDdotBatched", + ("hipblasDdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCdotu", ("hipblasCdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasCdotc", ("hipblasCdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotu", ("hipblasZdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotc", ("hipblasZdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasSscal", ("hipblasSscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSscalBatched", + ("hipblasSscalBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDscal", ("hipblasDscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDscalBatched", + ("hipblasDscalBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCscal", ("hipblasCscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsscal", ("hipblasCsscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZscal", ("hipblasZscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdscal", ("hipblasZdscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy", ("hipblasSaxpy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSaxpyBatched", + ("hipblasSaxpyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDaxpy", ("hipblasDaxpy", CONV_MATH_FUNC, API_BLAS)), + ("cublasCaxpy", ("hipblasCaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZaxpy", ("hipblasZaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasScopy", ("hipblasScopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScopyBatched", + ("hipblasScopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDcopy", ("hipblasDcopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDcopyBatched", + ("hipblasDcopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCcopy", ("hipblasCcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZcopy", ("hipblasZcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSswap", ("hipblasSswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap", ("hipblasDswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasCswap", ("hipblasCswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZswap", ("hipblasZswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamax", ("hipblasIsamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax", ("hipblasIdamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamax", ("hipblasIcamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamax", ("hipblasIzamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamin", ("hipblasIsamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin", ("hipblasIdamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamin", ("hipblasIcamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamin", ("hipblasIzamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSasum", ("hipblasSasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSasumBatched", + ("hipblasSasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDasum", ("hipblasDasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDasumBatched", + ("hipblasDasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScasum", ("hipblasScasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDzasum", ("hipblasDzasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrot", ("hipblasSrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot", ("hipblasDrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot", ("hipblasCrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsrot", ("hipblasCsrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrot", ("hipblasZrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdrot", ("hipblasZdrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotg", ("hipblasSrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotg", ("hipblasDrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrotg", ("hipblasCrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrotg", ("hipblasZrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotm", ("hipblasSrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotm", ("hipblasDrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotmg", ("hipblasSrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotmg", ("hipblasDrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgemv", ("hipblasSgemv", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSgemvBatched", + ("hipblasSgemvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDgemv", ("hipblasDgemv", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemv", ("hipblasCgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgemv", ("hipblasZgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgbmv", ("hipblasSgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDgbmv", ("hipblasDgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgbmv", ("hipblasCgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgbmv", ("hipblasZgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrmv", ("hipblasStrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmv", ("hipblasDtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmv", ("hipblasCtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmv", ("hipblasZtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbmv", ("hipblasStbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbmv", ("hipblasDtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbmv", ("hipblasCtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbmv", ("hipblasZtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpmv", ("hipblasStpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpmv", ("hipblasDtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpmv", ("hipblasCtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpmv", ("hipblasZtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsv", ("hipblasStrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrsv", ("hipblasDtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrsv", ("hipblasCtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsv", ("hipblasZtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpsv", ("hipblasStpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpsv", ("hipblasDtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpsv", ("hipblasCtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpsv", ("hipblasZtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbsv", ("hipblasStbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbsv", ("hipblasDtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbsv", ("hipblasCtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbsv", ("hipblasZtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymv", ("hipblasSsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymv", ("hipblasDsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymv", ("hipblasCsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymv", ("hipblasZsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemv", ("hipblasChemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemv", ("hipblasZhemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsbmv", ("hipblasSsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsbmv", ("hipblasDsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChbmv", ("hipblasChbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhbmv", ("hipblasZhbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspmv", ("hipblasSspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspmv", ("hipblasDspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpmv", ("hipblasChpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpmv", ("hipblasZhpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSger", ("hipblasSger", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger", ("hipblasDger", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeru", ("hipblasCgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgerc", ("hipblasCgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeru", ("hipblasZgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgerc", ("hipblasZgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr", ("hipblasSsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasDsyr", ("hipblasDsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasCher", ("hipblasCher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher", ("hipblasZher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr", ("hipblasSspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr", ("hipblasDspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr", ("hipblasChpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr", ("hipblasZhpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2", ("hipblasSsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2", ("hipblasDsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2", ("hipblasCher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2", ("hipblasZher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr2", ("hipblasSspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr2", ("hipblasDspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr2", ("hipblasChpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr2", ("hipblasZhpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgemmBatched", + ("hipblasSgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgemmBatched", + ("hipblasDgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasHgemmBatched", + ("hipblasHgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmStridedBatched", + ("hipblasSgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasDgemmStridedBatched", + ("hipblasDgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasHgemmStridedBatched", + ("hipblasHgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasCgemmBatched", + ("hipblasCgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mBatched", + ("hipblasCgemm3mBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemmBatched", + ("hipblasZgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmStridedBatched", + ( + "hipblasCgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasCgemm3mStridedBatched", + ( + "hipblasCgemm3mStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasZgemmStridedBatched", + ( + "hipblasZgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasHgemmStridedBatched", + ( + "hipblasHgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cublasSgemm", ("hipblasSgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm", ("hipblasDgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemm", ("hipblasCgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasZgemm", ("hipblasZgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasHgemm", ("hipblasHgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasSsyrk", ("hipblasSsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrk", ("hipblasDsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrk", ("hipblasCsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrk", ("hipblasZsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherk", ("hipblasCherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherk", ("hipblasZherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2k", ("hipblasSsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2k", ("hipblasDsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr2k", ("hipblasCsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr2k", ("hipblasZyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyrkx", ("hipblasSsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrkx", ("hipblasDsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrkx", ("hipblasCsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrkx", ("hipblasZsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2k", ("hipblasCher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2k", ("hipblasZher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherkx", ("hipblasCherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherkx", ("hipblasZherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymm", ("hipblasSsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymm", ("hipblasDsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymm", ("hipblasCsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymm", ("hipblasZsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemm", ("hipblasChemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemm", ("hipblasZhemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsm", ("hipblasStrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDtrsm", ("hipblasDtrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCtrsm", ("hipblasCtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsm", ("hipblasZtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasStrmm", ("hipblasStrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmm", ("hipblasDtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmm", ("hipblasCtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmm", ("hipblasZtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgeam", ("hipblasSgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgeam", ("hipblasDgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeam", ("hipblasCgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeam", ("hipblasZgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgetrfBatched", + ("hipblasSgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrfBatched", + ("hipblasDgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrfBatched", + ("hipblasCgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrfBatched", + ("hipblasZgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetriBatched", + ("hipblasSgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetriBatched", + ("hipblasDgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetriBatched", + ("hipblasCgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetriBatched", + ("hipblasZgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetrsBatched", + ("hipblasSgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrsBatched", + ("hipblasDgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrsBatched", + ("hipblasCgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrsBatched", + ("hipblasZgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSmatinvBatched", + ("hipblasSmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDmatinvBatched", + ("hipblasDmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCmatinvBatched", + ("hipblasCmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZmatinvBatched", + ("hipblasZmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgeqrfBatched", + ("hipblasSgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgeqrfBatched", + ("hipblasDgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgeqrfBatched", + ("hipblasCgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeqrfBatched", + ("hipblasZgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgelsBatched", + ("hipblasSgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgelsBatched", + ("hipblasDgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgelsBatched", + ("hipblasCgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgelsBatched", + ("hipblasZgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdgmm", ("hipblasSdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDdgmm", ("hipblasDdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCdgmm", ("hipblasCdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdgmm", ("hipblasZdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpttr", ("hipblasStpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpttr", ("hipblasDtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpttr", ("hipblasCtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpttr", ("hipblasZtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrttp", ("hipblasStrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrttp", ("hipblasDtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrttp", ("hipblasCtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrttp", ("hipblasZtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCreate_v2", ("hipblasCreate_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy_v2", ("hipblasDestroy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetVersion_v2", + ("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream_v2", ("hipblasGetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetPointerMode", + ("hipblasGetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode", + ("hipblasSetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasGetPointerMode_v2", + ("hipblasGetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode_v2", + ("hipblasSetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ("cublasSgemv_v2", ("hipblasSgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemv_v2", ("hipblasDgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemv_v2", + ("hipblasCgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemv_v2", + ("hipblasZgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgbmv_v2", + ("hipblasSgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgbmv_v2", + ("hipblasDgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgbmv_v2", + ("hipblasCgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgbmv_v2", + ("hipblasZgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmv_v2", + ("hipblasStrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmv_v2", + ("hipblasDtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmv_v2", + ("hipblasCtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmv_v2", + ("hipblasZtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbmv_v2", + ("hipblasStbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbmv_v2", + ("hipblasDtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbmv_v2", + ("hipblasCtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbmv_v2", + ("hipblasZtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpmv_v2", + ("hipblasStpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpmv_v2", + ("hipblasDtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpmv_v2", + ("hipblasCtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpmv_v2", + ("hipblasZtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsv_v2", + ("hipblasStrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsv_v2", + ("hipblasDtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsv_v2", + ("hipblasCtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsv_v2", + ("hipblasZtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpsv_v2", + ("hipblasStpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpsv_v2", + ("hipblasDtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpsv_v2", + ("hipblasCtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpsv_v2", + ("hipblasZtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbsv_v2", + ("hipblasStbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbsv_v2", + ("hipblasDtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbsv_v2", + ("hipblasCtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbsv_v2", + ("hipblasZtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymv_v2", + ("hipblasSsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymv_v2", + ("hipblasDsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymv_v2", + ("hipblasCsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymv_v2", + ("hipblasZsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemv_v2", + ("hipblasChemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemv_v2", + ("hipblasZhemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsbmv_v2", + ("hipblasSsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsbmv_v2", + ("hipblasDsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChbmv_v2", + ("hipblasChbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhbmv_v2", + ("hipblasZhbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspmv_v2", + ("hipblasSspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspmv_v2", + ("hipblasDspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpmv_v2", + ("hipblasChpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpmv_v2", + ("hipblasZhpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSger_v2", ("hipblasSger_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger_v2", ("hipblasDger_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgeru_v2", + ("hipblasCgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgerc_v2", + ("hipblasCergc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeru_v2", + ("hipblasZgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgerc_v2", + ("hipblasZgerc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSsyr_v2", ("hipblasSsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr_v2", ("hipblasDsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr_v2", ("hipblasCsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr_v2", ("hipblasZsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher_v2", ("hipblasCher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher_v2", ("hipblasZher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr_v2", ("hipblasSspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr_v2", ("hipblasDspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr_v2", ("hipblasChpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr_v2", ("hipblasZhpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSsyr2_v2", + ("hipblasSsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2_v2", + ("hipblasDsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2_v2", + ("hipblasCsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2_v2", + ("hipblasZsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2_v2", + ("hipblasCher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2_v2", + ("hipblasZher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspr2_v2", + ("hipblasSspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspr2_v2", + ("hipblasDspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpr2_v2", + ("hipblasChpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpr2_v2", + ("hipblasZhpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSgemm_v2", ("hipblasSgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm_v2", ("hipblasDgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemm_v2", + ("hipblasCgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3m", + ("hipblasCgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mEx", + ("hipblasCgemm3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm_v2", + ("hipblasZgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm3m", + ("hipblasZgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmEx", + ("hipblasSgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasGemmEx", ("hipblasGemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasGemmBatchedEx", + ("hipblasGemmBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGemmStridedBatchedEx", + ("hipblasGemmStridedBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmEx", + ("hipblasCgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasUint8gemmBias", + ("hipblasUint8gemmBias", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyrk_v2", + ("hipblasSsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyrk_v2", + ("hipblasDsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk_v2", + ("hipblasCsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyrk_v2", + ("hipblasZsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrkEx", + ("hipblasCsyrkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk3mEx", + ("hipblasCsyrk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk_v2", + ("hipblasCherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherkEx", + ("hipblasCherkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk3mEx", + ("hipblasCherk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZherk_v2", + ("hipblasZherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyr2k_v2", + ("hipblasSsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2k_v2", + ("hipblasDsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2k_v2", + ("hipblasCsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2k_v2", + ("hipblasZsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2k_v2", + ("hipblasCher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2k_v2", + ("hipblasZher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymm_v2", + ("hipblasSsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymm_v2", + ("hipblasDsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymm_v2", + ("hipblasCsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymm_v2", + ("hipblasZsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemm_v2", + ("hipblasChemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemm_v2", + ("hipblasZhemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsm_v2", + ("hipblasStrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsm_v2", + ("hipblasDtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsm_v2", + ("hipblasCtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsm_v2", + ("hipblasZtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmm_v2", + ("hipblasStrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmm_v2", + ("hipblasDtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmm_v2", + ("hipblasCtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmm_v2", + ("hipblasZtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSnrm2_v2", ("hipblasSnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2_v2", ("hipblasDnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScnrm2_v2", + ("hipblasScnrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDznrm2_v2", + ("hipblasDznrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDotEx", ("hipblasDotEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDotcEx", ("hipblasDotcEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSdot_v2", ("hipblasSdot_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDdot_v2", ("hipblasDdot_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCdotu_v2", + ("hipblasCdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCdotc_v2", + ("hipblasCdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotu_v2", + ("hipblasZdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotc_v2", + ("hipblasZdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScalEx", ("hipblasScalEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSscal_v2", ("hipblasSscal_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDscal_v2", ("hipblasDscal_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCscal_v2", + ("hipblasCscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsscal_v2", + ("hipblasCsscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZscal_v2", + ("hipblasZcsal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdscal_v2", + ("hipblasZdscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAxpyEx", ("hipblasAxpyEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy_v2", ("hipblasSaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDaxpy_v2", ("hipblasDaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCaxpy_v2", + ("hipblasCaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZaxpy_v2", + ("hipblasZaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScopy_v2", ("hipblasScopy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDcopy_v2", ("hipblasDcopy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCcopy_v2", + ("hipblasCcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZcopy_v2", + ("hipblasZcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSswap_v2", ("hipblasSswap_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap_v2", ("hipblasDswap_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCswap_v2", + ("hipblasCswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZswap_v2", + ("hipblasZswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamax_v2", ("hipblasIsamax_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax_v2", ("hipblasIdamax_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamax_v2", + ("hipblasIcamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamax_v2", + ("hipblasIzamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamin_v2", ("hipblasIsamin_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin_v2", ("hipblasIdamin_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamin_v2", + ("hipblasIcamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamin_v2", + ("hipblasIzamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSasum_v2", ("hipblasSasum_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDasum_v2", ("hipblasDasum_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScasum_v2", + ("hipblasScasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDzasum_v2", + ("hipblasDzasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSrot_v2", ("hipblasSrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot_v2", ("hipblasDrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot_v2", ("hipblasCrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasCsrot_v2", + ("hipblasCsrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasZrot_v2", ("hipblasZrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasZdrot_v2", + ("hipblasZdrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotg_v2", + ("hipblasSrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotg_v2", + ("hipblasDrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCrotg_v2", + ("hipblasCrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZrotg_v2", + ("hipblasZrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotm_v2", + ("hipblasSrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotm_v2", + ("hipblasDrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotmg_v2", + ("hipblasSrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotmg_v2", + ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasComputeType_t", + ("hipblasComputeType_t", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32I", + ("HIPBLAS_COMPUTE_32I", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F", + ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F_FAST_TF32", + ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_64F", + ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS) + ), + ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_BIAS", ("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_POINTER_MODE", ("HIPBLASLT_MATMUL_DESC_POINTER_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_POINTER_MODE_DEVICE", ("HIPBLASLT_POINTER_MODE_DEVICE", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLASLT_POINTER_MODE_HOST", ("HIPBLASLT_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_BLAS)), + ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutSetAttribute", ("hipblasLtMatrixLayoutSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT", ("HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", ("HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtCreate", ("hipblasLtCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)), + ( + "CURAND_STATUS_SUCCESS", + ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_VERSION_MISMATCH", + ("HIPRAND_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_NOT_INITIALIZED", + ("HIPRAND_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ALLOCATION_FAILED", + ("HIPRAND_STATUS_ALLOCATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_TYPE_ERROR", + ("HIPRAND_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_OUT_OF_RANGE", + ("HIPRAND_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_LENGTH_NOT_MULTIPLE", + ("HIPRAND_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED", + ( + "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED", + CONV_NUMERIC_LITERAL, + API_RAND, + ), + ), + ( + "CURAND_STATUS_LAUNCH_FAILURE", + ("HIPRAND_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_PREEXISTING_FAILURE", + ("HIPRAND_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INITIALIZATION_FAILED", + ("HIPRAND_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ARCH_MISMATCH", + ("HIPRAND_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INTERNAL_ERROR", + ("HIPRAND_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ("CURAND_RNG_TEST", ("HIPRAND_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND)), + ( + "mtgp32dc_params_fast_11213", + ("mtgp32dc_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_DEFAULT", + ("HIPRAND_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_XORWOW", + ("HIPRAND_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MRG32K3A", + ("HIPRAND_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MTGP32", + ("HIPRAND_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MT19937", + ("HIPRAND_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_PHILOX4_32_10", + ("HIPRAND_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_DEFAULT", + ("HIPRAND_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL32", + ("HIPRAND_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL64", + ("HIPRAND_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "curand_ORDERING_PSEUDO_BEST", + ( + "HIPRAND_ORDERING_PSEUDO_BEST", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_DEFAULT", + ( + "HIPRAND_ORDERING_PSEUDO_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_SEEDED", + ( + "HIPRAND_ORDERING_PSEUDO_SEEDED", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_QUASI_DEFAULT", + ( + "HIPRAND_ORDERING_QUASI_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_CHOOSE_BEST", + ("HIPRAND_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_ITR", + ("HIPRAND_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_KNUTH", + ("HIPRAND_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_HITR", + ("HIPRAND_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_M1", ("HIPRAND_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ("curand_M2", ("HIPRAND_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ( + "curand_BINARY_SEARCH", + ("HIPRAND_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DISCRETE_GAUSS", + ("HIPRAND_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_REJECTION", + ("HIPRAND_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEVICE_API", + ("HIPRAND_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_FAST_REJECTION", + ("HIPRAND_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_3RD", + ("HIPRAND_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEFINITION", + ("HIPRAND_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_POISSON", + ("HIPRAND_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curandCreateGenerator", ("hiprandCreateGenerator", CONV_MATH_FUNC, API_RAND)), + ( + "curandCreateGeneratorHost", + ("hiprandCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandCreatePoissonDistribution", + ("hiprandCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyDistribution", + ("hiprandDestroyDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyGenerator", + ("hiprandDestroyGenerator", CONV_MATH_FUNC, API_RAND), + ), + ("curandGenerate", ("hiprandGenerate", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateLogNormal", + ("hiprandGenerateLogNormal", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLogNormalDouble", + ("hiprandGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLongLong", + ("hiprandGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curandGenerateNormal", ("hiprandGenerateNormal", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateNormalDouble", + ("hiprandGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ("curandGeneratePoisson", ("hiprandGeneratePoisson", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateSeeds", ("hiprandGenerateSeeds", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateUniform", ("hiprandGenerateUniform", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateUniformDouble", + ("hiprandGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGetDirectionVectors32", + ("hiprandGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetDirectionVectors64", + ("hiprandGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetProperty", + ("hiprandGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetScrambleConstants32", + ( + "hiprandGetScrambleConstants32", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curandGetScrambleConstants64", + ( + "hiprandGetScrambleConstants64", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ("curandGetVersion", ("hiprandGetVersion", CONV_MATH_FUNC, API_RAND)), + ( + "curandSetGeneratorOffset", + ("hiprandSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetGeneratorOrdering", + ("hiprandSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandSetPseudoRandomGeneratorSeed", + ("hiprandSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetQuasiRandomGeneratorDimensions", + ("hiprandSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), + ), + ("curandSetStream", ("hiprandSetStream", CONV_MATH_FUNC, API_RAND)), + ("curand", ("hiprand", CONV_DEVICE_FUNC, API_RAND)), + ("curand4", ("hiprand4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_init", ("hiprand_init", CONV_DEVICE_FUNC, API_RAND)), + ("curand_log_normal", ("hiprand_log_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal_double", + ("hiprand_log_normal_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal2", ("hiprand_log_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal2_double", + ("hiprand_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal4", ("hiprand_log_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal4_double", + ("hiprand_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_mtgp32_single", + ("hiprand_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_mtgp32_single_specific", + ( + "hiprand_mtgp32_single_specific", + CONV_DEVICE_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_mtgp32_specific", + ("hiprand_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_normal", ("hiprand_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curandMakeMTGP32Constants", + ("hiprandMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curandMakeMTGP32KernelState", + ("hiprandMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal_double", ("hiprand_normal_double", CONV_DEVICE_FUNC, API_RAND)), + ("curand_normal2", ("hiprand_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal2_double", + ("hiprand_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal4", ("hiprand_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal4_double", + ("hiprand_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform", ("hiprand_uniform", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform_double", + ("hiprand_uniform_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_uniform2_double", + ("hiprand_uniform2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform4", ("hiprand_uniform4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform4_double", + ("hiprand_uniform4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_discrete", ("hiprand_discrete", CONV_DEVICE_FUNC, API_RAND)), + ("curand_discrete4", ("hiprand_discrete4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson", ("hiprand_poisson", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson4", ("hiprand_poisson4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_Philox4x32_10", + ("hiprand_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("mtgp32_kernel_params", ("mtgp32_kernel_params_t", CONV_MATH_FUNC, API_RAND)), + ("CUFFT_FORWARD", ("HIPFFT_FORWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUFFT_INVERSE", ("HIPFFT_BACKWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUFFT_COMPATIBILITY_DEFAULT", + ( + "HIPFFT_COMPATIBILITY_DEFAULT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cuComplex", ("hipComplex", CONV_TYPE, API_BLAS)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_BLAS)), + ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)), + ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)), + ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_PLAN", ("HIPFFT_INVALID_PLAN", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_ALLOC_FAILED", ("HIPFFT_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_TYPE", ("HIPFFT_INVALID_TYPE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_INVALID_VALUE", + ("HIPFFT_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INTERNAL_ERROR", + ("HIPFFT_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_EXEC_FAILED", ("HIPFFT_EXEC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_SETUP_FAILED", ("HIPFFT_SETUP_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_SIZE", ("HIPFFT_INVALID_SIZE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_UNALIGNED_DATA", + ("HIPFFT_UNALIGNED_DATA", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INCOMPLETE_PARAMETER_LIST", + ("HIPFFT_INCOMPLETE_PARAMETER_LIST", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INVALID_DEVICE", + ("HIPFFT_INVALID_DEVICE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_PARSE_ERROR", ("HIPFFT_PARSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_NO_WORKSPACE", ("HIPFFT_NO_WORKSPACE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_NOT_IMPLEMENTED", + ("HIPFFT_NOT_IMPLEMENTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_LICENSE_ERROR", + ("HIPFFT_LICENSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_NOT_SUPPORTED", + ("HIPFFT_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("cufftType_t", ("hipfftType_t", CONV_TYPE, API_FFT)), + ("cufftType", ("hipfftType", CONV_TYPE, API_FFT)), + ("CUFFT_R2C", ("HIPFFT_R2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2R", ("HIPFFT_C2R", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2C", ("HIPFFT_C2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_D2Z", ("HIPFFT_D2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2D", ("HIPFFT_Z2D", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2Z", ("HIPFFT_Z2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "cufftCompatibility_t", + ("hipfftCompatibility_t", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "cufftCompatibility", + ("hipfftCompatibility", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_COMPATIBILITY_FFTW_PADDING", + ( + "HIPFFT_COMPATIBILITY_FFTW_PADDING", + CONV_NUMERIC_LITERAL, + API_FFT, + HIP_UNSUPPORTED, + ), + ), + ("cufftReal", ("hipfftReal", CONV_TYPE, API_FFT)), + ("cufftDoubleReal", ("hipfftDoubleReal", CONV_TYPE, API_FFT)), + ("cufftComplex", ("hipfftComplex", CONV_TYPE, API_FFT)), + ("cufftDoubleComplex", ("hipfftDoubleComplex", CONV_TYPE, API_FFT)), + ("cufftHandle", ("hipfftHandle", CONV_TYPE, API_FFT)), + ("cufftPlan1d", ("hipfftPlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan2d", ("hipfftPlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan3d", ("hipfftPlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlanMany", ("hipfftPlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan1d", ("hipfftMakePlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan2d", ("hipfftMakePlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan3d", ("hipfftMakePlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany", ("hipfftMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany64", ("hipfftMakePlanMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany64", ("hipfftGetSizeMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate1d", ("hipfftEstimate1d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate2d", ("hipfftEstimate2d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate3d", ("hipfftEstimate3d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimateMany", ("hipfftEstimateMany", CONV_MATH_FUNC, API_FFT)), + ("cufftCreate", ("hipfftCreate", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize1d", ("hipfftGetSize1d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize2d", ("hipfftGetSize2d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize3d", ("hipfftGetSize3d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany", ("hipfftGetSizeMany", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize", ("hipfftGetSize", CONV_MATH_FUNC, API_FFT)), + ("cufftSetWorkArea", ("hipfftSetWorkArea", CONV_MATH_FUNC, API_FFT)), + ( + "cufftSetAutoAllocation", + ("hipfftSetAutoAllocation", CONV_MATH_FUNC, API_FFT), + ), + ("cufftXtExec", ("hipfftXtExec", CONV_MATH_FUNC, API_FFT)), + ("cufftXtMakePlanMany", ("hipfftXtMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2C", ("hipfftExecC2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecR2C", ("hipfftExecR2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2R", ("hipfftExecC2R", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2Z", ("hipfftExecZ2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecD2Z", ("hipfftExecD2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2D", ("hipfftExecZ2D", CONV_MATH_FUNC, API_FFT)), + ("cufftSetStream", ("hipfftSetStream", CONV_MATH_FUNC, API_FFT)), + ("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)), + ("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)), + ( + "cufftGetProperty", + ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED), + ), + ("nvrtcResult", ("hiprtcResult", CONV_TYPE, API_RTC)), + ("NVRTC_SUCCESS", ("HIPRTC_SUCCESS", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_OUT_OF_MEMORY", + ("HIPRTC_ERROR_OUT_OF_MEMORY", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_PROGRAM_CREATION_FAILURE", + ("HIPRTC_ERROR_PROGRAM_CREATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_INPUT", + ("HIPRTC_ERROR_INVALID_INPUT", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_PROGRAM", + ("HIPRTC_ERROR_INVALID_PROGRAM", CONV_TYPE, API_RTC), + ), + ("NVRTC_ERROR_COMPILATION", ("HIPRTC_ERROR_COMPILATION", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE", + ("HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", + ("HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID", + ("HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INTERNAL_ERROR", + ("HIPRTC_ERROR_INTERNAL_ERROR", CONV_TYPE, API_RTC), + ), + ("nvrtcGetErrorString", ("hiprtcGetErrorString", CONV_JIT, API_RTC)), + ("nvrtcVersion", ("hiprtcVersion", CONV_JIT, API_RTC)), + ("nvrtcProgram", ("hiprtcProgram", CONV_TYPE, API_RTC)), + ("nvrtcAddNameExpression", ("hiprtcAddNameExpression", CONV_JIT, API_RTC)), + ("nvrtcCompileProgram", ("hiprtcCompileProgram", CONV_JIT, API_RTC)), + ("nvrtcCreateProgram", ("hiprtcCreateProgram", CONV_JIT, API_RTC)), + ("nvrtcDestroyProgram", ("hiprtcDestroyProgram", CONV_JIT, API_RTC)), + ("nvrtcGetLoweredName", ("hiprtcGetLoweredName", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLog", ("hiprtcGetProgramLog", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLogSize", ("hiprtcGetProgramLogSize", CONV_JIT, API_RTC)), + ("nvrtcGetPTX", ("hiprtcGetCode", CONV_JIT, API_RTC)), + ("nvrtcGetPTXSize", ("hiprtcGetCodeSize", CONV_JIT, API_RTC)), + ("thrust::cuda", ("thrust::hip", CONV_MATH_FUNC, API_BLAS)), + ( + "cudaCpuDeviceId", + ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + # The caffe2 directory does a string match; pytorch does a word-boundary match. + # Patterns such as 'cub::' will not match for pytorch. + # We list all current uses of cub symbols for this reason. + ("cub::", ("hipcub::", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMax", ("hipcub::ArgMax", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMin", ("hipcub::ArgMin", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_SCAN_WARP_SCANS", ("hipcub::BLOCK_SCAN_WARP_SCANS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_REDUCE_WARP_REDUCTIONS", ("hipcub::BLOCK_REDUCE_WARP_REDUCTIONS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_WARP_TRANSPOSE", ("hipcub::BLOCK_STORE_WARP_TRANSPOSE", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_LOAD_DIRECT", ("hipcub::BLOCK_LOAD_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_DIRECT", ("hipcub::BLOCK_STORE_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ( + "cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", + ("hipcub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", CONV_SPECIAL_FUNC, API_RUNTIME) + ), + ("cub::BlockReduce", ("hipcub::BlockReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockScan", ("hipcub::BlockScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockLoad", ("hipcub::BlockLoad", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockStore", ("hipcub::BlockStore", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRakingLayout", ("hipcub::BlockRakingLayout", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRadixSort", ("hipcub::BlockRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Uninitialized", ("hipcub::Uninitialized", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::RowMajorTid", ("hipcub::RowMajorTid", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CachingDeviceAllocator", ("hipcub::CachingDeviceAllocator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CountingInputIterator", ("hipcub::CountingInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRadixSort", ("hipcub::DeviceRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceReduce", ("hipcub::DeviceReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRunLengthEncode", ("hipcub::DeviceRunLengthEncode", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceScan", ("hipcub::DeviceScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedRadixSort", ("hipcub::DeviceSegmentedRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedReduce", ("hipcub::DeviceSegmentedReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSelect", ("hipcub::DeviceSelect", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::FpLimits", ("hipcub::FpLimits", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::KeyValuePair", ("hipcub::KeyValuePair", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Min", ("hipcub::Min", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Sum", ("hipcub::Sum", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Log2", ("hipcub::Log2", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::LaneId", ("hipcub::LaneId", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpMask", ("hipcub::WarpMask", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleIndex", ("hipcub::ShuffleIndex", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleDown", ("hipcub::ShuffleDown", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgIndexInputIterator", ("hipcub::ArgIndexInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::TransformInputIterator", ("hipcub::TransformInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpReduce", ("hipcub::WarpReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CTA_SYNC", ("hipcub::CTA_SYNC", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("nvtxMark", ("roctxMark", CONV_OTHER, API_ROCTX)), + ("nvtxMarkA", ("roctxMarkA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePushA", ("roctxRangePushA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)), + ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)), + ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)), + ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_STATUS_OK", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_ERROR_INSUFFICIENT_SIZE", ("RSMI_STATUS_INSUFFICIENT_SIZE", CONV_OTHER, API_ROCMSMI)), + ("nvmlDevice_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PStatus_t", ("bool", CONV_OTHER, API_ROCMSMI)), + ("nvmlProcessInfo_t", ("rsmi_process_info_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PCapsIndex_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ] +) + +# pyrefly: ignore [no-matching-overload] +CUDA_SPECIAL_MAP = collections.OrderedDict( + [ + # SPARSE + ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cuComplex", ("hipComplex", CONV_TYPE, API_SPECIAL)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_SPECIAL)), + ( + "CUSPARSE_POINTER_MODE_HOST", + ("HIPSPARSE_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPECIAL)), + ( + "cusparseCreateMatDescr", + ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseDestroyMatDescr", + ("hipsparseDestroyMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDiagType_t", ("hipsparseDiagType_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_UNIT", ("HIPSPARSE_DIAG_TYPE_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_NON_UNIT", ("HIPSPARSE_DIAG_TYPE_NON_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatDiagType", ("hipsparseSetMatDiagType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseFillMode_t", ("hipsparseFillMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_UPPER", ("HIPSPARSE_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_LOWER", ("HIPSPARSE_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatFillMode", ("hipsparseSetMatFillMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDirection_t", ("hipsparseDirection_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIRECTION_ROW", ("HIPSPARSE_DIRECTION_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIRECTION_COLUMN", ("HIPSPARSE_DIRECTION_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSolvePolicy_t", ("hipsparseSolvePolicy_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_NO_LEVEL", ("HIPSPARSE_SOLVE_POLICY_NO_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_USE_LEVEL", ("HIPSPARSE_SOLVE_POLICY_USE_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseCreateBsrsv2Info", ("hipsparseCreateBsrsv2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateBsrsm2Info", ("hipsparseCreateBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsv2Info", ("hipsparseDestroyBsrsv2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsm2Info", ("hipsparseDestroyBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmm", ("hipsparseSbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmm", ("hipsparseDbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmm", ("hipsparseCbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmm", ("hipsparseZbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmv", ("hipsparseSbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmv", ("hipsparseDbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmv", ("hipsparseCbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmv", ("hipsparseZbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_bufferSize", ("hipsparseSbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_bufferSize", ("hipsparseDbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_bufferSize", ("hipsparseCbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_bufferSize", ("hipsparseZbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_analysis", ("hipsparseSbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_analysis", ("hipsparseDbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_analysis", ("hipsparseCbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_analysis", ("hipsparseZbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_solve", ("hipsparseSbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_solve", ("hipsparseDbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_solve", ("hipsparseCbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_solve", ("hipsparseZbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_bufferSize", ("hipsparseSbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_bufferSize", ("hipsparseDbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_bufferSize", ("hipsparseCbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_bufferSize", ("hipsparseZbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_analysis", ("hipsparseSbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_analysis", ("hipsparseDbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_analysis", ("hipsparseCbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_analysis", ("hipsparseZbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_solve", ("hipsparseSbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_solve", ("hipsparseDbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_solve", ("hipsparseCbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_solve", ("hipsparseZbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrmm2", ("hipsparseCcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrmm2", ("hipsparseZcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm", ("hipsparseScsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm", ("hipsparseDcsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcsrsort_bufferSizeExt", + ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreateCsrgemm2Info", ("hipsparseCreateCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseDestroyCsrgemm2Info", + ("hipsparseDestroyCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseXcsrgemm2Nnz", ("hipsparseXcsrgemm2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2_bufferSizeExt", ("hipsparseDcsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2_bufferSizeExt", ("hipsparseScsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2", ("hipsparseDcsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2", ("hipsparseScsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSetPointerMode", ("hipsparseSetPointerMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrgeam2Nnz", ("hipsparseXcsrgeam2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2_bufferSizeExt", ("hipsparseScsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2_bufferSizeExt", ("hipsparseDcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2_bufferSizeExt", ("hipsparseCcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2_bufferSizeExt", ("hipsparseZcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2", ("hipsparseScsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2", ("hipsparseDcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2", ("hipsparseCcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2", ("hipsparseZcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsm2_zeroPivot", ("hipsparseXbsrsm2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsv2_zeroPivot", ("hipsparseXbsrsv2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcoosort_bufferSizeExt", + ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseXcoosortByRow", + ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseCreateIdentityPermutation", + ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseSetMatIndexBase", + ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV", ("hipsparseSpMV", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV_bufferSize", ("hipsparseSpMV_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM", ("hipsparseSpMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM_bufferSize", ("hipsparseSpMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnMat", ("hipsparseCreateDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetStridedBatch", ("hipsparseCsrSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnVec", ("hipsparseCreateDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnMat", ("hipsparseDestroyDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnVec", ("hipsparseDestroyDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroySpMat", ("hipsparseDestroySpMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_destroyDescr", ("hipsparseSpGEMM_destroyDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCoo", ("hipsparseCreateCoo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_bufferSize", ("hipsparseSDDMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_preprocess", ("hipsparseSDDMM_preprocess", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM", ("hipsparseSDDMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetPointers", ("hipsparseCsrSetPointers", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMVAlg_t", ("hipsparseSpMVAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMMAlg_t", ("hipsparseSpMMAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseIndexType_t", ("hipsparseIndexType_t", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseMatDescr", ("hipsparseMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnMatDescr", ("hipsparseDnMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnVecDescr", ("hipsparseDnVecDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpMatDescr", ("hipsparseSpMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpGEMMDescr", ("hipsparseSpGEMMDescr", CONV_TYPE, API_SPECIAL)), + ("cusparseDnMatDescr_t", ("hipsparseDnMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseDnVecDescr_t", ("hipsparseDnVecDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMatDescr_t", ("hipsparseSpMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SDDMM_ALG_DEFAULT", ("HIPSPARSE_SDDMM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUSPARSE_STATUS_SUCCESS", + ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_NOT_INITIALIZED", + ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ALLOC_FAILED", + ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INVALID_VALUE", + ("HIPSPARSE_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MAPPING_ERROR", + ("HIPSPARSE_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_EXECUTION_FAILED", + ("HIPSPARSE_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INTERNAL_ERROR", + ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + ( + "HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_STATUS_ARCH_MISMATCH", + ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ZERO_PIVOT", + ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_TRANSPOSE", + ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_NON_TRANSPOSE", + ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + ( + "HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_INDEX_BASE_ZERO", + ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_INDEX_BASE_ONE", + ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_MATRIX_TYPE_GENERAL", + ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + # SparseLt + ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)), + ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_MAX_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_SPLIT_K", ("HIPSPARSELT_MATMUL_SPLIT_K", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_SPLIT_K_MODE", ("HIPSPARSELT_MATMUL_SPLIT_K_MODE", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSplitKMode_t", ("hipsparseLtSplitKMode_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseGetErrorString", ("hipsparseGetErrorString", CONV_MATH_FUNC, API_SPECIAL)), + # SOLVER + ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUBLAS_OP_T", + ("HIPSOLVER_OP_T", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_OP_C", + ("HIPSOLVER_OP_C", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasFillMode_t", ("hipsolverFillMode_t", CONV_TYPE, API_SPECIAL)), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPSOLVER_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPSOLVER_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasSideMode_t", ("hipsolverSideMode_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_SIDE_LEFT", ("HIPSOLVER_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUBLAS_SIDE_RIGHT", ("HIPSOLVER_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("cusolverEigMode_t", ("hipsolverEigMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_VECTOR", ("HIPSOLVER_EIG_MODE_VECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_NOVECTOR", ("HIPSOLVER_EIG_MODE_NOVECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("syevjInfo_t", ("hipsolverSyevjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateSyevjInfo", ("hipsolverDnCreateSyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXsyevjSetSortEig", ("hipsolverDnXsyevjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroySyevjInfo", ("hipsolverDnDestroySyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("gesvdjInfo_t", ("hipsolverGesvdjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateGesvdjInfo", ("hipsolverDnCreateGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXgesvdjSetSortEig", ("hipsolverDnXgesvdjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroyGesvdjInfo", ("hipsolverDnDestroyGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("cusolverDnHandle_t", ("hipsolverDnHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreate", ("hipsolverDnCreate", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnSetStream", ("hipsolverDnSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroy", ("hipsolverDnDestroy", CONV_MATH_FUNC, API_SPECIAL)), + + # from aten/src/ATen/native/hip/linalg/HIPSolver.cpp + ('cusolverDnParams_t', ('hipsolverDnParams_t', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf', ('hipsolverDnCgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf_bufferSize', ('hipsolverDnCgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd', ('hipsolverDnCgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd_bufferSize', ('hipsolverDnCgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj', ('hipsolverDnCgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched', ('hipsolverDnCgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched_bufferSize', ('hipsolverDnCgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj_bufferSize', ('hipsolverDnCgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf', ('hipsolverDnCgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf_bufferSize', ('hipsolverDnCgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrs', ('hipsolverDnCgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd', ('hipsolverDnCheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd_bufferSize', ('hipsolverDnCheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj', ('hipsolverDnCheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched', ('hipsolverDnCheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched_bufferSize', ('hipsolverDnCheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj_bufferSize', ('hipsolverDnCheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf', ('hipsolverDnCpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrfBatched', ('hipsolverDnCpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf_bufferSize', ('hipsolverDnCpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrs', ('hipsolverDnCpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrsBatched', ('hipsolverDnCpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr', ('hipsolverDnCungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr_bufferSize', ('hipsolverDnCungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr', ('hipsolverDnCunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr_bufferSize', ('hipsolverDnCunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf', ('hipsolverDnDgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf_bufferSize', ('hipsolverDnDgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd', ('hipsolverDnDgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd_bufferSize', ('hipsolverDnDgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj', ('hipsolverDnDgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched', ('hipsolverDnDgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched_bufferSize', ('hipsolverDnDgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj_bufferSize', ('hipsolverDnDgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrf', ('hipsolverDnDgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrf_bufferSize', ('hipsolverDnDgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrs', ('hipsolverDnDgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr', ('hipsolverDnDorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr_bufferSize', ('hipsolverDnDorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr', ('hipsolverDnDormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr_bufferSize', ('hipsolverDnDormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf', ('hipsolverDnDpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrfBatched', ('hipsolverDnDpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf_bufferSize', ('hipsolverDnDpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrs', ('hipsolverDnDpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrsBatched', ('hipsolverDnDpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd', ('hipsolverDnDsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd_bufferSize', ('hipsolverDnDsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj', ('hipsolverDnDsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched', ('hipsolverDnDsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched_bufferSize', ('hipsolverDnDsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj_bufferSize', ('hipsolverDnDsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf', ('hipsolverDnSgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf_bufferSize', ('hipsolverDnSgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd', ('hipsolverDnSgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd_bufferSize', ('hipsolverDnSgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj', ('hipsolverDnSgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched', ('hipsolverDnSgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched_bufferSize', ('hipsolverDnSgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj_bufferSize', ('hipsolverDnSgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf', ('hipsolverDnSgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf_bufferSize', ('hipsolverDnSgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrs', ('hipsolverDnSgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr', ('hipsolverDnSorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr_bufferSize', ('hipsolverDnSorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr', ('hipsolverDnSormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr_bufferSize', ('hipsolverDnSormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf', ('hipsolverDnSpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrfBatched', ('hipsolverDnSpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf_bufferSize', ('hipsolverDnSpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrs', ('hipsolverDnSpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrsBatched', ('hipsolverDnSpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd', ('hipsolverDnSsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd_bufferSize', ('hipsolverDnSsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj', ('hipsolverDnSsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched', ('hipsolverDnSsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched_bufferSize', ('hipsolverDnSsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj_bufferSize', ('hipsolverDnSsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf', ('hipsolverDnXgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf_bufferSize', ('hipsolverDnXgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf', ('hipsolverDnXpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf_bufferSize', ('hipsolverDnXpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrs', ('hipsolverDnXpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd', ('hipsolverDnXsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd_bufferSize', ('hipsolverDnXsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf', ('hipsolverDnZgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf_bufferSize', ('hipsolverDnZgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd', ('hipsolverDnZgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd_bufferSize', ('hipsolverDnZgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj', ('hipsolverDnZgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched', ('hipsolverDnZgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched_bufferSize', ('hipsolverDnZgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj_bufferSize', ('hipsolverDnZgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf', ('hipsolverDnZgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf_bufferSize', ('hipsolverDnZgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrs', ('hipsolverDnZgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd', ('hipsolverDnZheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd_bufferSize', ('hipsolverDnZheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj', ('hipsolverDnZheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched', ('hipsolverDnZheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched_bufferSize', ('hipsolverDnZheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj_bufferSize', ('hipsolverDnZheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf', ('hipsolverDnZpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrfBatched', ('hipsolverDnZpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf_bufferSize', ('hipsolverDnZpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrs', ('hipsolverDnZpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrsBatched', ('hipsolverDnZpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr', ('hipsolverDnZungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr_bufferSize', ('hipsolverDnZungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr', ('hipsolverDnZunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr_bufferSize', ('hipsolverDnZunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + + # sytrf + ('cusolverDnDsytrf_bufferSize', ('hipsolverDnDsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf_bufferSize', ('hipsolverDnSsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf_bufferSize', ('hipsolverDnZsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf_bufferSize', ('hipsolverDnCsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsytrf', ('hipsolverDnDsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf', ('hipsolverDnSsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf', ('hipsolverDnZsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf', ('hipsolverDnCsytrf', CONV_MATH_FUNC, API_SPECIAL)), + + # gesdva strided + ( + 'cusolverDnSgesvdaStridedBatched_bufferSize', + ('hipsolverDnSgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnDgesvdaStridedBatched_bufferSize', + ('hipsolverDnDgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnCgesvdaStridedBatched_bufferSize', + ('hipsolverDnCgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnZgesvdaStridedBatched_bufferSize', + ('hipsolverDnZgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ('cusolverDnSgesvdaStridedBatched', ('hipsolverDnSgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdaStridedBatched', ('hipsolverDnDgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdaStridedBatched', ('hipsolverDnCgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdaStridedBatched', ('hipsolverDnZgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + + # gesvdj SetXXX + ('cusolverDnXgesvdjSetTolerance', ('hipsolverDnXgesvdjSetTolerance', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgesvdjSetMaxSweeps', ('hipsolverDnXgesvdjSetMaxSweeps', CONV_MATH_FUNC, API_SPECIAL)), + ] +) + +# pyrefly: ignore [no-matching-overload] +PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("USE_CUDA", ("USE_ROCM", API_PYTORCH)), + ("TORCH_CUDA_CPP_API", ("TORCH_HIP_CPP_API", API_PYTORCH)), + ("TORCH_CUDA_CU_API", ("TORCH_HIP_API", API_PYTORCH)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("cudaHostAllocator", ("hipHostAllocator", API_PYTORCH)), + ("cudaDeviceAllocator", ("hipDeviceAllocator", API_PYTORCH)), + ("define MAX_NUM_BLOCKS 200", ("define MAX_NUM_BLOCKS 64", API_PYTORCH)), + ("cuda::CUDAGuard", ("hip::HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAGuard", ("HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAGuard", + ("hip::OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("OptionalCUDAGuard", ("OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::CUDAStreamGuard", + ("hip::HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("CUDAStreamGuard", ("HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAStreamGuard", + ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "OptionalCUDAStreamGuard", + ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::CUDAMultiStreamGuard", + ("hip::HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "CUDAMultiStreamGuard", + ("HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + # Only get needs to be transformed this way; all the other ones can go + # straight to the normal versions hip::HIPCachingAllocator + ( + "cuda::CUDACachingAllocator::get", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "CUDACachingAllocator::get", + ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDACachingAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "cuda::CUDACachingAllocator::raw_alloc", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::raw_alloc", API_PYTORCH), + ), + ( + "CUDACachingAllocator::raw_alloc", + ("HIPCachingAllocatorMasqueradingAsCUDA::raw_alloc", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::raw_alloc_with_stream", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::raw_alloc_with_stream", API_PYTORCH), + ), + ( + "CUDACachingAllocator::raw_alloc_with_stream", + ("HIPCachingAllocatorMasqueradingAsCUDA::raw_alloc_with_stream", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::raw_delete", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::raw_delete", API_PYTORCH), + ), + ( + "CUDACachingAllocator::raw_delete", + ("HIPCachingAllocatorMasqueradingAsCUDA::raw_delete", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::init", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::init", API_PYTORCH), + ), + ( + "CUDACachingAllocator::init", + ("HIPCachingAllocatorMasqueradingAsCUDA::init", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getMemoryFraction", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getMemoryFraction", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getMemoryFraction", + ("HIPCachingAllocatorMasqueradingAsCUDA::getMemoryFraction", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::setMemoryFraction", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::setMemoryFraction", API_PYTORCH), + ), + ( + "CUDACachingAllocator::setMemoryFraction", + ("HIPCachingAllocatorMasqueradingAsCUDA::setMemoryFraction", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::emptyCache", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::emptyCache", API_PYTORCH), + ), + ( + "CUDACachingAllocator::emptyCache", + ("HIPCachingAllocatorMasqueradingAsCUDA::emptyCache", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::enable", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::enable", API_PYTORCH), + ), + ( + "CUDACachingAllocator::enable", + ("HIPCachingAllocatorMasqueradingAsCUDA::enable", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::isEnabled", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::isEnabled", API_PYTORCH), + ), + ( + "CUDACachingAllocator::isEnabled", + ("HIPCachingAllocatorMasqueradingAsCUDA::isEnabled", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::cacheInfo", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::cacheInfo", API_PYTORCH), + ), + ( + "CUDACachingAllocator::cacheInfo", + ("HIPCachingAllocatorMasqueradingAsCUDA::cacheInfo", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getBaseAllocation", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getBaseAllocation", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getBaseAllocation", + ("HIPCachingAllocatorMasqueradingAsCUDA::getBaseAllocation", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getDeviceStats", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getDeviceStats", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getDeviceStats", + ("HIPCachingAllocatorMasqueradingAsCUDA::getDeviceStats", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::resetAccumulatedStats", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::resetAccumulatedStats", API_PYTORCH), + ), + ( + "CUDACachingAllocator::resetAccumulatedStats", + ("HIPCachingAllocatorMasqueradingAsCUDA::resetAccumulatedStats", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::resetPeakStats", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::resetPeakStats", API_PYTORCH), + ), + ( + "CUDACachingAllocator::resetPeakStats", + ("HIPCachingAllocatorMasqueradingAsCUDA::resetPeakStats", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::snapshot", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::snapshot", API_PYTORCH), + ), + ( + "CUDACachingAllocator::snapshot", + ("HIPCachingAllocatorMasqueradingAsCUDA::snapshot", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getCheckpointState", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getCheckpointState", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getCheckpointState", + ("HIPCachingAllocatorMasqueradingAsCUDA::getCheckpointState", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::setCheckpointState", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::setCheckpointState", API_PYTORCH), + ), + ( + "CUDACachingAllocator::setCheckpointState", + ("HIPCachingAllocatorMasqueradingAsCUDA::setCheckpointState", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::setCheckpointPoolState", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::setCheckpointPoolState", API_PYTORCH), + ), + ( + "CUDACachingAllocator::setCheckpointPoolState", + ("HIPCachingAllocatorMasqueradingAsCUDA::setCheckpointPoolState", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::beginAllocateToPool", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::beginAllocateToPool", API_PYTORCH), + ), + ( + "CUDACachingAllocator::beginAllocateToPool", + ("HIPCachingAllocatorMasqueradingAsCUDA::beginAllocateToPool", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::endAllocateToPool", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::endAllocateToPool", API_PYTORCH), + ), + ( + "CUDACachingAllocator::endAllocateToPool", + ("HIPCachingAllocatorMasqueradingAsCUDA::endAllocateToPool", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordHistory", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::recordHistory", API_PYTORCH), + ), + ( + "CUDACachingAllocator::recordHistory", + ("HIPCachingAllocatorMasqueradingAsCUDA::recordHistory", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordAnnotation", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::recordAnnotation", API_PYTORCH), + ), + ( + "CUDACachingAllocator::recordAnnotation", + ("HIPCachingAllocatorMasqueradingAsCUDA::recordAnnotation", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::pushCompileContext", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::pushCompileContext", API_PYTORCH), + ), + ( + "CUDACachingAllocator::pushCompileContext", + ("HIPCachingAllocatorMasqueradingAsCUDA::pushCompileContext", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::popCompileContext", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::popCompileContext", API_PYTORCH), + ), + ( + "CUDACachingAllocator::popCompileContext", + ("HIPCachingAllocatorMasqueradingAsCUDA::popCompileContext", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::isHistoryEnabled", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::isHistoryEnabled", API_PYTORCH), + ), + ( + "CUDACachingAllocator::isHistoryEnabled", + ("HIPCachingAllocatorMasqueradingAsCUDA::isHistoryEnabled", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::checkPoolLiveAllocations", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::checkPoolLiveAllocations", API_PYTORCH), + ), + ( + "CUDACachingAllocator::checkPoolLiveAllocations", + ("HIPCachingAllocatorMasqueradingAsCUDA::checkPoolLiveAllocations", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::attachOutOfMemoryObserver", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::attachOutOfMemoryObserver", API_PYTORCH), + ), + ( + "CUDACachingAllocator::attachOutOfMemoryObserver", + ("HIPCachingAllocatorMasqueradingAsCUDA::attachOutOfMemoryObserver", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::attachAllocatorTraceTracker", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::attachAllocatorTraceTracker", API_PYTORCH), + ), + ( + "CUDACachingAllocator::attachAllocatorTraceTracker", + ("HIPCachingAllocatorMasqueradingAsCUDA::attachAllocatorTraceTracker", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::releasePool", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::releasePool", API_PYTORCH), + ), + ( + "CUDACachingAllocator::releasePool", + ("HIPCachingAllocatorMasqueradingAsCUDA::releasePool", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::createOrIncrefPool", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::createOrIncrefPool", API_PYTORCH), + ), + ( + "CUDACachingAllocator::createOrIncrefPool", + ("HIPCachingAllocatorMasqueradingAsCUDA::createOrIncrefPool", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::setUseOnOOM", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::setUseOnOOM", API_PYTORCH), + ), + ( + "CUDACachingAllocator::setUseOnOOM", + ("HIPCachingAllocatorMasqueradingAsCUDA::setUseOnOOM", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::setNoSplit", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::setNoSplit", API_PYTORCH), + ), + ( + "CUDACachingAllocator::setNoSplit", + ("HIPCachingAllocatorMasqueradingAsCUDA::setNoSplit", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getPoolUseCount", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getPoolUseCount", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getPoolUseCount", + ("HIPCachingAllocatorMasqueradingAsCUDA::getPoolUseCount", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::getIpcDevPtr", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::getIpcDevPtr", API_PYTORCH), + ), + ( + "CUDACachingAllocator::getIpcDevPtr", + ("HIPCachingAllocatorMasqueradingAsCUDA::getIpcDevPtr", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::shareIpcHandle", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::shareIpcHandle", API_PYTORCH), + ), + ( + "CUDACachingAllocator::shareIpcHandle", + ("HIPCachingAllocatorMasqueradingAsCUDA::shareIpcHandle", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::name", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::name", API_PYTORCH), + ), + ( + "CUDACachingAllocator::name", + ("HIPCachingAllocatorMasqueradingAsCUDA::name", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::memcpyAsync", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::memcpyAsync", API_PYTORCH), + ), + ( + "CUDACachingAllocator::memcpyAsync", + ("HIPCachingAllocatorMasqueradingAsCUDA::memcpyAsync", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::enablePeerAccess", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::enablePeerAccess", API_PYTORCH), + ), + ( + "CUDACachingAllocator::enablePeerAccess", + ("HIPCachingAllocatorMasqueradingAsCUDA::enablePeerAccess", API_PYTORCH), + ), + ( + "cuda::CUDAAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDAAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getStreamFromPool", + ("hip::getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromPool", ("getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getStreamFromExternal", + ("hip::getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromExternal", ("getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getDefaultCUDAStream", + ("getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getCurrentCUDAStream", + ("hip::getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getCurrentCUDAStream", + ("getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::setCurrentCUDAStream", + ("hip::setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "setCurrentCUDAStream", + ("setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "ATen/cudnn/Handle.h", + ("ATen/miopen/Handle.h", API_PYTORCH), + ), + # TODO: Undo this special-case; see the header for motivation behind this + # hack. It's VERY important this is only applied to PyTorch HIPify. + ( + "c10/cuda/CUDAGuard.h", + ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDACachingAllocator.h", + ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAStream.h", + ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), + ), + ("gloo/cuda.h", ("gloo/hip.h", API_PYTORCH)), + ( + "gloo/cuda_allreduce_halving_doubling.h", + ("gloo/hip_allreduce_halving_doubling.h", API_PYTORCH), + ), + ( + "gloo/cuda_allreduce_halving_doubling_pipelined.h", + ("gloo/hip_allreduce_halving_doubling_pipelined.h", API_PYTORCH), + ), + ("gloo/cuda_allreduce_ring.h", ("gloo/hip_allreduce_ring.h", API_PYTORCH)), + ("gloo/cuda_allreduce_ring_chunked.h", ("gloo/hip_allreduce_ring_chunked.h", API_PYTORCH)), + ( + "gloo/cuda_broadcast_one_to_all.h", + ("gloo/hip_broadcast_one_to_all.h", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceHalvingDoublingPipelined", + ("gloo::HipAllreduceHalvingDoublingPipelined", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceRingChunked", + ("gloo::HipAllreduceRingChunked", API_PYTORCH), + ), + ("gloo::CudaBroadcastOneToAll", ("gloo::HipBroadcastOneToAll", API_PYTORCH)), + ("gloo::CudaHostWorkspace", ("gloo::HipHostWorkspace", API_PYTORCH)), + ("gloo::CudaDeviceWorkspace", ("gloo::HipDeviceWorkspace", API_PYTORCH)), + ("CUDNN_RNN_RELU", ("miopenRNNRELU", API_PYTORCH)), + ("CUDNN_RNN_TANH", ("miopenRNNTANH", API_PYTORCH)), + ("CUDNN_LSTM", ("miopenLSTM", API_PYTORCH)), + ("CUDNN_GRU", ("miopenGRU", API_PYTORCH)), + ("cudnnRNNMode_t", ("miopenRNNMode_t", API_PYTORCH)), + ("magma_queue_create_from_cuda", ("magma_queue_create_from_hip", API_PYTORCH)), + ] +) + +# pyrefly: ignore [no-matching-overload] +CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)), + ("PYTORCH_CUDA_ALLOC_CONF", ("PYTORCH_CUDA_ALLOC_CONF", API_CAFFE2)), + ("cuda_stream", ("hip_stream", API_CAFFE2)), + # if the header is a native hip folder (under hip directory), + # there is no need to add a hip path to it; the trie in hipify script + # takes this mapping order to forbid further replacement + ("/hip/", ("/hip/", API_CAFFE2)), + ("/context_gpu", ("/hip/context_gpu", API_CAFFE2)), + ("/common_gpu", ("/hip/common_gpu", API_CAFFE2)), + ("/cuda_nccl_gpu", ("/hip/hip_nccl_gpu", API_CAFFE2)), + ("/mixed_utils", ("/hip/mixed_utils", API_CAFFE2)), + ("/operator_fallback_gpu", ("/hip/operator_fallback_gpu", API_CAFFE2)), + ( + "/spatial_batch_norm_op_impl", + ("/hip/spatial_batch_norm_op_impl", API_CAFFE2), + ), + ( + "/recurrent_network_executor_gpu", + ("/hip/recurrent_network_executor_gpu", API_CAFFE2), + ), + ( + "/generate_proposals_op_util_nms_gpu", + ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2), + ), + ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)), + ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)), + ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)), + ("/top_k_radix_selection", ("/hip/top_k_radix_selection", API_CAFFE2)), + ("/GpuAtomics", ("/hip/GpuAtomics", API_CAFFE2)), + ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)), + ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)), + ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)), + ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)), + ("/sgd/adagrad_fused_op_gpu.cuh", ("/sgd/hip/adagrad_fused_op_gpu.cuh", API_CAFFE2)), + ("/operators/segment_reduction_op_gpu.cuh", ("/operators/hip/segment_reduction_op_gpu.cuh", API_CAFFE2)), + ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)), + ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)), + ("REGISTER_CUDA_OPERATOR", ("REGISTER_HIP_OPERATOR", API_CAFFE2)), + ("CUDA_1D_KERNEL_LOOP", ("HIP_1D_KERNEL_LOOP", API_CAFFE2)), + ("CUDAContext", ("HIPContext", API_CAFFE2)), + ("CAFFE_CUDA_NUM_THREADS", ("CAFFE_HIP_NUM_THREADS", API_CAFFE2)), + ("HasCudaGPU", ("HasHipGPU", API_CAFFE2)), + ("__expf", ("expf", API_CAFFE2)), + ("CUBLAS_ENFORCE", ("HIPBLAS_ENFORCE", API_CAFFE2)), + ("CUBLAS_CHECK", ("HIPBLAS_CHECK", API_CAFFE2)), + ("cublas_handle", ("hipblas_handle", API_CAFFE2)), + ("CURAND_ENFORCE", ("HIPRAND_ENFORCE", API_CAFFE2)), + ("CURAND_CHECK", ("HIPRAND_CHECK", API_CAFFE2)), + ("curandGenerateUniform", ("hiprandGenerateUniform", API_CAFFE2)), + ("curand_generator", ("hiprand_generator", API_CAFFE2)), + ("CaffeCudaGetDevice", ("CaffeHipGetDevice", API_CAFFE2)), + # do not rename CUDA_KERNEL_ASSERT, lazyInitCUDA in caffe2 sources + # the ordered dict guarantees this pattern will match first, before "CUDA" + ("CUDA_KERNEL_ASSERT", ("CUDA_KERNEL_ASSERT", API_CAFFE2)), + ("lazyInitCUDA", ("lazyInitCUDA", API_CAFFE2)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_CAFFE2)), + ("CUDA", ("HIP", API_CAFFE2)), + ("Cuda", ("Hip", API_CAFFE2)), + ("cuda_", ("hip_", API_CAFFE2)), + ("_cuda", ("_hip", API_CAFFE2)), + ("CUDNN", ("MIOPEN", API_CAFFE2)), + ("CuDNN", ("MIOPEN", API_CAFFE2)), + ("cudnn", ("miopen", API_CAFFE2)), + ("namespace cuda", ("namespace hip", API_CAFFE2)), + ("cuda::CUDAGuard", ("hip::HIPGuard", API_CAFFE2)), + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuard", API_CAFFE2)), + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuard", API_CAFFE2)), + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuard", API_CAFFE2)), + ("c10/cuda/CUDAGuard.h", ("c10/hip/HIPGuard.h", API_CAFFE2)), + ("gloo/cuda", ("gloo/hip", API_CAFFE2)), + ] +) + +# We must treat very carefully here. Blanket conversions like are done +# in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, +# because a regex for CUDA will also match a filename like CUDAGuard.h, +# but the HIPIFY script doesn't presently move the file and so the substitution +# will be invalid. Instead, we specifically list out every identifier +# and file from c10/cuda which may be used externally, and do substitutions this +# way. +# +# NB: if you want a transformation to ONLY apply to the c10/ directory, +# put it as API_CAFFE2 +# pyrefly: ignore [no-matching-overload] +C10_MAPPINGS = collections.OrderedDict( + [ + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("CUDA_LAUNCH_BLOCKING=1", ("AMD_SERIALIZE_KERNEL=3", API_C10)), + ("CUDA_LAUNCH_BLOCKING", ("AMD_SERIALIZE_KERNEL", API_C10)), + ("cuda::compat::", ("hip::compat::", API_C10)), + ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)), + ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)), + ("c10/cuda/CUDADeviceAssertionHost.h", ("c10/hip/HIPDeviceAssertionHost.h", API_C10)), + ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)), + ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), + ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), + ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), + ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), + ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), + ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), + ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)), + ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)), + ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)), + ( + "c10/cuda/impl/cuda_cmake_macros.h", + ("c10/hip/impl/hip_cmake_macros.h", API_C10), + ), + ("C10_CUDA_CHECK", ("C10_HIP_CHECK", API_C10)), + ("C10_CUDA_CHECK_WARN", ("C10_HIP_CHECK_WARN", API_C10)), + ("C10_CUDA_ERROR_HANDLED", ("C10_HIP_ERROR_HANDLED", API_C10)), + ("C10_CUDA_IGNORE_ERROR", ("C10_HIP_IGNORE_ERROR", API_C10)), + ("C10_CUDA_CLEAR_ERROR", ("C10_HIP_CLEAR_ERROR", API_C10)), + ("c10::cuda", ("c10::hip", API_C10)), + ("cuda::CUDAStream", ("hip::HIPStream", API_C10)), + ("CUDAStream", ("HIPStream", API_C10)), + # This substitution is not permissible, because there's another copy of this + # function in torch/cuda.h + # ("cuda::device_count", ("hip::device_count", API_C10)), + ("cuda::current_device", ("hip::current_device", API_C10)), + ("cuda::set_device", ("hip::set_device", API_C10)), + ("cuda::device_synchronize", ("hip::device_synchronize", API_C10)), + ("cuda::getStreamFromPool", ("hip::getStreamFromPool", API_C10)), + ("getStreamFromPool", ("getStreamFromPool", API_C10)), + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStream", API_C10)), + ("getDefaultCUDAStream", ("getDefaultHIPStream", API_C10)), + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStream", API_C10)), + ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)), + ("cuda::get_cuda_check_prefix", ("hip::get_cuda_check_prefix", API_C10)), + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)), + ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), + ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)), + ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)), + ("cuda::CUDAAllocatorConfig", ("hip::HIPAllocatorConfig", API_C10)), + ("CUDAAllocatorConfig", ("HIPAllocatorConfig", API_C10)), + ("pinned_use_cuda_host_register", ("pinned_use_hip_host_register", API_C10)), + ("c10::cuda::CUDAAllocator", ("c10::hip::HIPAllocator", API_C10)), + ("cuda::CUDAAllocator", ("hip::HIPAllocator", API_C10)), + ("CUDAStreamCaptureModeGuard", ("HIPStreamCaptureModeGuard", API_C10)), + ("cuda::CUDAStreamCaptureModeGuard", ("cuda::HIPStreamCaptureModeGuard", API_C10)), + ("CUDAAllocator", ("HIPAllocator", API_C10)), + ("C10_CUDA_KERNEL_LAUNCH_CHECK", ("C10_HIP_KERNEL_LAUNCH_CHECK", API_C10)), + ("CUDAKernelLaunchRegistry", ("HIPKernelLaunchRegistry", API_C10)), + ("c10::cuda::get_cuda_check_suffix", ("c10::hip::get_hip_check_suffix", API_C10)), + ("c10::cuda::get_cuda_error_help", ("c10::hip::get_hip_error_help", API_C10)), + ] +) + +# NB: C10 mappings are more specific than Caffe2 mappings, so run them +# first +CUDA_TO_HIP_MAPPINGS = [ + CUDA_IDENTIFIER_MAP, + CUDA_TYPE_NAME_MAP, + CUDA_INCLUDE_MAP, + CUDA_SPECIAL_MAP, + C10_MAPPINGS, + PYTORCH_SPECIFIC_MAPPINGS, + CAFFE2_SPECIFIC_MAPPINGS, +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b2a863d60f398587bf62a9a60cc0dae03532c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py @@ -0,0 +1,1186 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" The Python Hipify script. +## +# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved. +# 2017-2018 Advanced Micro Devices, Inc. and +# Facebook Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +import argparse +import fnmatch +import re +import shutil +import sys +import os + +from . import constants +from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS +from .cuda_to_hip_mappings import MATH_TRANSPILATIONS + +from collections.abc import Iterator +from collections.abc import Mapping, Iterable +from enum import Enum +import functools +import hashlib + +class CurrentState(Enum): + INITIALIZED = 1 + DONE = 2 + +class HipifyResult: + def __init__(self, current_state, hipified_path) -> None: + self.current_state = current_state + self.hipified_path = hipified_path + self.status = "" + + def __str__(self) -> str: + return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}") + +HipifyFinalResult = dict[str, HipifyResult] +HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n" +HIPIFY_FINAL_RESULT: HipifyFinalResult = {} + +# Hardcode the PyTorch template map +"""This dictionary provides the mapping from PyTorch kernel template types +to their actual types.""" +PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} + +__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter', + 'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group', + 'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared', + 'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file', + 'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', + 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify'] + + +class InputError(Exception): + # Exception raised for errors in the input. + + def __init__(self, message) -> None: + super().__init__(message) + self.message = message + + def __str__(self) -> str: + return f"Input error: {self.message}" + + +def openf(filename, mode): + return open(filename, mode, errors='ignore') + + +# Color coding for printing +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +# To the programmer, the output of hipify most likely are intermediates. +# This class allows users of hipify to ask for a cleanup by running the +# hipify and compilation in a with instantiating this context manager class +# with keep_intermediates=False. +# The main usecase is the cpp_extensions, specifically the load method. +# It is a good idea to keep intermediates (in case of errors or to +# not recompile unchanged files), but in cases where you don't want to +# keep them (e.g. in the CI), this can be used to remove files. +class GeneratedFileCleaner: + """Context Manager to clean up generated files""" + def __init__(self, keep_intermediates=False) -> None: + self.keep_intermediates = keep_intermediates + self.files_to_clean = set() + self.dirs_to_clean = [] + + def __enter__(self): + return self + + def open(self, fn, *args, **kwargs): + if not os.path.exists(fn): + self.files_to_clean.add(os.path.abspath(fn)) + # pyrefly: ignore [not-iterable] + return open(fn, *args, **kwargs) + + def makedirs(self, dn, exist_ok=False) -> None: + parent, n = os.path.split(dn) + if not n: + parent, n = os.path.split(parent) + if parent and n and not os.path.exists(parent): + self.makedirs(parent, exist_ok=True) + if not os.path.isdir(dn) or not exist_ok: + os.mkdir(dn) + self.dirs_to_clean.append(os.path.abspath(dn)) + + def __exit__(self, type, value, traceback): + if not self.keep_intermediates: + for f in self.files_to_clean: + os.unlink(f) + for d in self.dirs_to_clean[::-1]: + os.rmdir(d) + +# Follow UNIX convention for paths to use '/' instead of '\\' on Windows +def _to_unix_path(path: str) -> str: + return path.replace(os.sep, '/') + +def match_extensions(filename: str, extensions: Iterable) -> bool: + """Helper method to see if filename ends with certain extension""" + return any(filename.endswith(e) for e in extensions) + + +def _fnmatch(filepath, patterns): + return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) + + +def matched_files_iter( + root_path: str, + includes: Iterable = (), + ignores: Iterable = (), + extensions: Iterable = (), + out_of_place_only: bool = False, + is_pytorch_extension: bool = False) -> Iterator[str]: + + exact_matches = set(includes) + + # This is a very rough heuristic; really, we want to avoid scanning + # any file which is not checked into source control, but this script + # needs to work even if you're in a Git or Hg checkout, so easier to + # just block the biggest time sinks that won't matter in the + # end. + for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True): + rel_dirpath = os.path.relpath(abs_dirpath, root_path) + if rel_dirpath == '.': + # Blah blah blah O(n) blah blah + if ".git" in dirs: + dirs.remove(".git") + if "build" in dirs: + dirs.remove("build") + if "third_party" in dirs: + dirs.remove("third_party") + dirs.append("third_party/nvfuser") + for filename in filenames: + filepath = _to_unix_path(os.path.join(abs_dirpath, filename)) + rel_filepath = _to_unix_path(os.path.join(rel_dirpath, filename)) + # We respect extensions, UNLESS you wrote the entire + # filename verbatim, in which case we always accept it + if ( + _fnmatch(filepath, includes) + and (not _fnmatch(filepath, ignores)) + and (match_extensions(filepath, extensions) or filepath in exact_matches) + ): + if not is_pytorch_extension: # for pytorch extensions, consider all files + if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath): + continue + if out_of_place_only and not is_out_of_place(rel_filepath): + continue + yield filepath + + +def preprocess_file_and_save_result( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> None: + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path) + HIPIFY_FINAL_RESULT[fin_path] = hipify_result + result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats, + hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + # Show what happened + if show_progress and "ignored" not in result.status: + print( + fin_path, "->", + result.hipified_path, result.status, flush=True) + + HIPIFY_FINAL_RESULT[fin_path] = result + + +def compute_stats(stats) -> None: + unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]} + + # Print the number of unsupported calls + print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}") + + # Print the list of unsupported calls + print(", ".join(unsupported_calls)) + + # Print the number of kernel launches + print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}") + + +def add_dim3(kernel_string, cuda_kernel): + '''adds dim3() to the second and third arguments in the kernel launch''' + count = 0 + closure = 0 + kernel_string = kernel_string.replace("<<<", "").replace(">>>", "") + arg_locs: list[dict[str, int]] = [{} for _ in range(2)] + arg_locs[count]['start'] = 0 + for ind, c in enumerate(kernel_string): + if count > 1: + break + if c == "(": + closure += 1 + elif c == ")": + closure -= 1 + if (c == "," or ind == len(kernel_string) - 1) and closure == 0: + arg_locs[count]['end'] = ind + (c != ",") + count += 1 + if count < 2: + arg_locs[count]['start'] = ind + 1 + + first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1] + second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']] + + first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ") + second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ") + + first_arg_dim3 = f"dim3({first_arg_clean})" + second_arg_dim3 = f"dim3({second_arg_clean})" + + first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3) + second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3) + cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3) + return cuda_kernel + + +RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+') + + +def processKernelLaunches(string, stats): + """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" + # Concat the namespace with the kernel names. (Find cleaner way of doing this later). + string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string) + + def grab_method_and_template(in_kernel): + # The positions for relevant kernel components. + pos = { + "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]}, + "kernel_name": {"start": -1, "end": -1}, + "template": {"start": -1, "end": -1} + } + + # Count for balancing template + count = {"<>": 0} + + # Status for whether we are parsing a certain item. + START = 0 + AT_TEMPLATE = 1 + AFTER_TEMPLATE = 2 + AT_KERNEL_NAME = 3 + + status = START + + # Parse the string character by character + for i in range(pos["kernel_launch"]["start"] - 1, -1, -1): + char = string[i] + + # Handle Templating Arguments + if status in (START, AT_TEMPLATE): + if char == ">": + if status == START: + status = AT_TEMPLATE + pos["template"]["end"] = i + count["<>"] += 1 + + if char == "<": + count["<>"] -= 1 + if count["<>"] == 0 and (status == AT_TEMPLATE): + pos["template"]["start"] = i + status = AFTER_TEMPLATE + + # Handle Kernel Name + if status != AT_TEMPLATE: + if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}: + if status != AT_KERNEL_NAME: + status = AT_KERNEL_NAME + pos["kernel_name"]["end"] = i + + # Case: Kernel name starts the string. + if i == 0: + pos["kernel_name"]["start"] = 0 + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + else: + # Potential ending point if we're already traversing a kernel's name. + if status == AT_KERNEL_NAME: + pos["kernel_name"]["start"] = i + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + def find_kernel_bounds(string): + """Finds the starting and ending points for all kernel launches in the string.""" + kernel_end = 0 + kernel_positions = [] + + # Continue until we cannot find any more kernels anymore. + while string.find("<<<", kernel_end) != -1: + # Get kernel starting position (starting from the previous ending point) + kernel_start = string.find("<<<", kernel_end) + + # Get kernel ending position (adjust end point past the >>>) + kernel_end = string.find(">>>", kernel_start) + 3 + if kernel_end <= 0: + raise InputError("no kernel end found") + + # Add to list of traversed kernels + kernel_positions.append({"start": kernel_start, "end": kernel_end, + "group": string[kernel_start: kernel_end]}) + + return kernel_positions + + # Replace comments and string literals from the code so that find_kernel_bounds does not + # wrongly capture kernels in comments and string literals. + # This function replaces them with "x" to keep positions. + def mask_comments(string): + in_comment = '' + prev_c = '' + new_string = '' + for c in string: + if in_comment == '': + # Outside comments + if c == '/' and prev_c == '/': + in_comment = '//' + elif c == '*' and prev_c == '/': + in_comment = '/*' + elif c == '"' and prev_c != '\\' and prev_c != "'": + in_comment = '"' + elif in_comment == '//': + # In // xxx + if c == '\r' or c == '\n': + in_comment = '' + elif in_comment == '/*': + # In /* xxx */ + if c == '/' and prev_c == '*': + in_comment = '' + elif in_comment == '"': + # In "" + if c == '"' and prev_c != '\\': + in_comment = '' + prev_c = c + if in_comment == '': + new_string += c + else: + new_string += 'x' + return new_string + + # Grab positional ranges of all kernel launches + get_kernel_positions = list(find_kernel_bounds(mask_comments(string))) + output_string = string + + # Replace each CUDA kernel with a HIP kernel. + for kernel in get_kernel_positions: + # Get kernel components + params = grab_method_and_template(kernel) + + # Find parenthesis after kernel launch + parenthesis = string.find("(", kernel["end"]) + + # Extract cuda kernel + cuda_kernel = string[params[0]["start"]:parenthesis + 1] + kernel_string = string[kernel['start']:kernel['end']] + end_param_index = 0 if params[1]['end'] == -1 else 1 + kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1] + cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel) + # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size) + num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")"))) + + hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace( + ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace( + ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")") + + # Replace cuda kernel with hip kernel + output_string = output_string.replace(cuda_kernel, hip_kernel) + + # Update the statistics + stats["kernel_launches"].append(hip_kernel) + + return output_string + + +def find_closure_group(input_string, start, group): + """Generalization for finding a balancing closure group + + if group = ["(", ")"], then finds the first balanced parentheses. + if group = ["{", "}"], then finds the first balanced bracket. + + Given an input string, a starting position in the input string, and the group type, + find_closure_group returns the positions of group[0] and group[1] as a tuple. + + Example: + >>> find_closure_group("(hi)", 0, ["(", ")"]) + (0, 3) + """ + + inside_parenthesis = False + parens = 0 + pos = start + p_start, p_end = -1, -1 + + while pos < len(input_string): + if input_string[pos] == group[0]: + if inside_parenthesis is False: + inside_parenthesis = True + parens = 1 + p_start = pos + else: + parens += 1 + elif input_string[pos] == group[1] and inside_parenthesis: + parens -= 1 + + if parens == 0: + p_end = pos + return p_start, p_end + + pos += 1 + return None, None + + +def find_bracket_group(input_string, start): + """Finds the first balanced parentheses.""" + return find_closure_group(input_string, start, group=["{", "}"]) + + +def find_parentheses_group(input_string, start): + """Finds the first balanced bracket.""" + return find_closure_group(input_string, start, group=["(", ")"]) + + +RE_ASSERT = re.compile(r"\bassert[ ]*\(") + + +def replace_math_functions(input_string): + """FIXME: Temporarily replace std:: invocations of math functions + with non-std:: versions to prevent linker errors NOTE: This + can lead to correctness issues when running tests, since the + correct version of the math function (exp/expf) might not get + called. Plan is to remove this function once HIP supports + std:: math function calls inside device code + + """ + output_string = input_string + for func in MATH_TRANSPILATIONS: + output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(') + + return output_string + + +RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()") + + +def hip_header_magic(input_string): + """If the file makes kernel builtin calls and does not include the cuda_runtime.h header, + then automatically add an #include to match the "magic" includes provided by NVCC. + TODO: + Update logic to ignore cases where the cuda_runtime.h is included by another file. + """ + + # Copy the input. + output_string = input_string + + # Check if one of the following headers is already included. + headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"] + if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers): + return output_string + + # Rough logic to detect if we're inside device code + hasDeviceLogic: int + hasDeviceLogic = "hipLaunchKernelGGL" in output_string + hasDeviceLogic += "__global__" in output_string + hasDeviceLogic += "__shared__" in output_string + hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None + + # If device logic found, provide the necessary header. + if hasDeviceLogic: + output_string = '#include "hip/hip_runtime.h"\n' + input_string + + return output_string + + +RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;") + + +def replace_extern_shared(input_string): + """ + Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead. + See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__ + Examples: + "extern __shared__ char smemChar[];" + => "HIP_DYNAMIC_SHARED( char, smemChar)" + "extern __shared__ unsigned char smem[];" + => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)" + """ + output_string = input_string + output_string = RE_EXTERN_SHARED.sub( + lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string) + + return output_string + + +def get_hip_file_path(rel_filepath, is_pytorch_extension=False): + """ + Returns the new name of the hipified file + """ + # At the moment, some PyTorch source files are HIPified in place. The predicate + # is_out_of_place tells us if this is the case or not. + if os.path.isabs(rel_filepath): + raise AssertionError("rel_filepath must be a relative path") + if not is_pytorch_extension and not is_out_of_place(rel_filepath): + return rel_filepath + + dirpath, filename = os.path.split(rel_filepath) + root, ext = os.path.splitext(filename) + + # Here's the plan: + # + # In general, we need to disambiguate the HIPified filename so that + # it gets a different name from the original filename, so + # that we don't overwrite the original file + # + # There's a lot of different naming conventions across PyTorch + # and Caffe2, but the general recipe is to convert occurrences + # of cuda/gpu to hip, and add hip if there are no occurrences + # of cuda/gpu anywhere. + # + # Concretely, we do the following: + # + # - If there is a directory component named "cuda", replace + # it with "hip", AND + # + # - If the file name contains "CUDA", replace it with "HIP", AND + # + # - ALWAYS replace '.cu' with '.hip', because those files + # contain CUDA kernels that needs to be hipified and processed with + # hip compiler + # + # - If we are not hipifying a PyTorch extension, and the parent + # directory name did not change as a result of the above + # transformations, insert "hip" in the file path + # as the direct parent folder of the file + # + # - If we are hipifying a PyTorch extension, and the parent directory + # name as well as the filename (incl. extension) did not change as + # a result of the above transformations, insert "_hip" in the filename + # + # This isn't set in stone; we might adjust this to support other + # naming conventions. + + if ext == '.cu': + ext = '.hip' + + orig_filename = filename + orig_dirpath = dirpath + + dirpath = dirpath.replace('cuda', 'hip') + dirpath = dirpath.replace('CUDA', 'HIP') + dirpath = dirpath.replace('THC', 'THH') + + root = root.replace('cuda', 'hip') + root = root.replace('CUDA', 'HIP') + # Special case to handle caffe2/core/THCCachingAllocator + if dirpath != "caffe2/core": + root = root.replace('THC', 'THH') + + if not is_pytorch_extension and dirpath == orig_dirpath: + dirpath = os.path.join(dirpath, 'hip') + + if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename: + root = root + "_hip" + + return os.path.join(dirpath, root + ext) + + +def is_out_of_place(rel_filepath) -> bool: + if os.path.isabs(rel_filepath): + raise AssertionError("rel_filepath must be a relative path") + if rel_filepath.startswith("torch/"): + return False + if rel_filepath.startswith("third_party/nvfuser/"): + return False + if rel_filepath.startswith("tools/autograd/templates/"): + return False + return True + + +# Keep this synchronized with includes/ignores in build_amd.py +def is_pytorch_file(rel_filepath) -> bool: + if os.path.isabs(rel_filepath): + raise AssertionError("rel_filepath must be a relative path") + if rel_filepath.startswith("aten/"): + if rel_filepath.startswith("aten/src/ATen/core/"): + return False + return True + if rel_filepath.startswith("torch/"): + return True + if rel_filepath.startswith("third_party/nvfuser/"): + return True + if rel_filepath.startswith("third_party/fbgemm/"): + return True + if rel_filepath.startswith("tools/autograd/templates/"): + return True + return False + + +def is_cusparse_file(rel_filepath): + if is_pytorch_file(rel_filepath): + return "sparse" in rel_filepath.lower() + return False + + +def is_special_file(rel_filepath) -> bool: + if is_pytorch_file(rel_filepath): + if "sparse" in rel_filepath.lower(): + return True + elif "linalg" in rel_filepath.lower(): + if "batchlinearalgebralibblas" in rel_filepath.lower(): + return False # don't use "special" mappings for this specific linalg cublas file + return True + return False + +def is_caffe2_gpu_file(rel_filepath): + if os.path.isabs(rel_filepath): + raise AssertionError("rel_filepath must be a relative path") + if rel_filepath.startswith("c10/cuda"): + return True + filename = os.path.basename(rel_filepath) + _, ext = os.path.splitext(filename) + # pyrefly: ignore [unsupported-operation] + return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) + +class TrieNode: + """A Trie node whose children are represented as a directory of char: TrieNode. + A special char '' represents end of word + """ + + def __init__(self) -> None: + self.children = {} + +class Trie: + """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern. + The corresponding Regex should match much faster than a simple Regex union.""" + + def __init__(self) -> None: + """Initialize the trie with an empty root node.""" + self.root = TrieNode() + self._hash = hashlib.md5(usedforsecurity=False) + self._digest = self._hash.digest() + + def add(self, word) -> None: + """Add a word to the Trie. """ + self._hash.update(word.encode()) + self._digest = self._hash.digest() + node = self.root + + for char in word: + node.children.setdefault(char, TrieNode()) + node = node.children[char] + node.children[''] = True # Mark the end of the word + + def dump(self): + """Return the root node of Trie. """ + return self.root + + def quote(self, char): + """ Escape a char for regex. """ + return re.escape(char) + + def search(self, word): + """Search whether word is present in the Trie. + Returns True if yes, else return False""" + node = self.root + for char in word: + if char in node.children: + node = node.children[char] + else: + return False + + # make sure to check the end-of-word marker present + return '' in node.children + + @functools.lru_cache # noqa: B019 + def _pattern(self, root, digest): + """Convert a Trie into a regular expression pattern + + Memoized on the hash digest of the trie, which is built incrementally + during add(). + """ + node = root + + if "" in node.children and len(node.children.keys()) == 1: + return None + + alt = [] # store alternative patterns + cc = [] # store char to char classes + q = 0 # for node representing the end of word + for char in sorted(node.children.keys()): + if isinstance(node.children[char], TrieNode): + try: + recurse = self._pattern(node.children[char], self._digest) + alt.append(self.quote(char) + recurse) + except Exception: + cc.append(self.quote(char)) + else: + q = 1 + cconly = not len(alt) > 0 + + if len(cc) > 0: + if len(cc) == 1: + alt.append(cc[0]) + else: + alt.append('[' + ''.join(cc) + ']') + + if len(alt) == 1: + result = alt[0] + else: + result = "(?:" + "|".join(alt) + ")" + + if q: + if cconly: + result += "?" + else: + result = f"(?:{result})?" + return result + + def pattern(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + + def export_to_regex(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + +CAFFE2_TRIE = Trie() +CAFFE2_MAP = {} +PYTORCH_TRIE = Trie() +PYTORCH_MAP: dict[str, object] = {} + +# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip. +# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance. +# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex. +# In the case of SPARSE, we must use the hip types for complex instead of the roc types, +# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority. +# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place. +# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices. +# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling. +PYTORCH_SPECIAL_MAP = {} + +for mapping in CUDA_TO_HIP_MAPPINGS: + if not isinstance(mapping, Mapping): + raise TypeError("Expected each mapping in CUDA_TO_HIP_MAPPINGS to be a Mapping") + for src, value in mapping.items(): + dst = value[0] + meta_data = value[1:] + if constants.API_CAFFE2 not in meta_data: + PYTORCH_TRIE.add(src) + # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL + # do not overwrite PYTORCH_MAP, store dst separately + if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""): + PYTORCH_SPECIAL_MAP[src] = dst + else: + PYTORCH_MAP[src] = dst + if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data: + CAFFE2_TRIE.add(src) + CAFFE2_MAP[src] = dst +RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex()) +RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)') + +RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"') +RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>') +RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"') +RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh + +""" +Returns a HipifyResult object with the following details: + "hipified_path" : absolute path of hipified source file + "status" : "ok" if hipified file was written out + "skipped" if an identical hipified file already existed or hipified file couldn't be written out + "ignored" if the source file was a hipified file itself or not meant to be hipified + "current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified + CurrentState.DONE if source file is done with hipification process +""" + + +def preprocessor( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> HipifyResult: + """ Executes the CUDA -> HIP conversion on the specified file. """ + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + filepath = _to_unix_path(filepath) + hipify_result = HIPIFY_FINAL_RESULT[fin_path] + if filepath not in all_files: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, not to be hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + rel_filepath = _to_unix_path(os.path.relpath(filepath, output_directory)) + + with open(fin_path, encoding='utf-8') as fin: + if fin.readline() == HIPIFY_C_BREADCRUMB: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, input is hipified output]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + fin.seek(0) + output_source = fin.read() + + orig_output_source = output_source + + # get_hip_file_path needs a relative path to work correctly + fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension))) + if not os.path.exists(os.path.dirname(fout_path)): + clean_ctx.makedirs(os.path.dirname(fout_path)) + + # unsupported_calls statistics reporting is broken atm + def pt_repl(m): + return PYTORCH_MAP[m.group(0)] + + def pt_special_repl(m): + # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings + return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m)) + + + if is_pytorch_extension: + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + if is_special_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source) + elif is_pytorch_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + def c2_repl(m): + return CAFFE2_MAP[m.group(0)] + output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) + + # Header rewrites + def mk_repl(templ, include_current_dir=True): + def repl(m): + f = m.group(1) + filename = os.path.basename(f) + if ( + f.startswith(("ATen/cuda", + "ATen/native/cuda", + "ATen/native/nested/cuda", + "ATen/native/quantized/cuda", + "ATen/native/sparse/cuda", + "ATen/native/transformers/cuda", + "THC/")) or + (f.startswith("THC") and not f.startswith("THCP")) + ): + return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) + # if filename is one of the files being hipified for this extension + if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)): + header_dir = None + header_filepath = None + # If include_current_dir True, look first in same dir as the including source file + if include_current_dir: + header_dir_to_check = os.path.dirname(fin_path) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If not found, look in include dirs one by one and first match wins + if header_filepath is None: + for header_include_dir in header_include_dirs: + header_dir_to_check = os.path.join(output_directory, header_include_dir) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If header file not found, keep as is + if header_filepath is None: + return m.group(0) + # Hipify header file first if needed + if header_filepath not in HIPIFY_FINAL_RESULT: + preprocess_file_and_save_result(output_directory, + header_filepath, + all_files, header_include_dirs, stats, hip_clang_launch, + is_pytorch_extension, clean_ctx, show_progress) + elif header_filepath in HIPIFY_FINAL_RESULT: + header_result = HIPIFY_FINAL_RESULT[header_filepath] + if header_result.current_state == CurrentState.INITIALIZED: + # get_hip_file_path needs a relative path to work correctly + header_rel_path = os.path.relpath(header_filepath, output_directory) + header_fout_path = os.path.abspath(os.path.join(output_directory, + get_hip_file_path(header_rel_path, is_pytorch_extension))) + header_result.hipified_path = header_fout_path + HIPIFY_FINAL_RESULT[header_filepath] = header_result + return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None + else header_filepath, header_dir)) + hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path + return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None + else header_filepath, header_dir))) + + return m.group(0) + return repl + output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source) + output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source) + output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source) + + # CMakeLists.txt rewrites + if filepath.endswith('CMakeLists.txt'): + output_source = output_source.replace('CUDA', 'HIP') + output_source = output_source.replace('THC', 'THH') + output_source = RE_CU_SUFFIX.sub('.hip', output_source) + + # Perform Kernel Launch Replacements + if not hip_clang_launch: + output_source = processKernelLaunches(output_source, stats) + + # Replace std:: with non-std:: versions + if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath: + output_source = replace_math_functions(output_source) + + # Include header if device code is contained. + output_source = hip_header_magic(output_source) + + # Replace the extern __shared__ + # NOTE: No longer needed after transition from hcc to hipclang. + # output_source = replace_extern_shared(output_source) + + # Don't write out identical hipified files for extensions if dirpath has not changed + if ( + is_pytorch_extension + and orig_output_source == output_source + and os.path.dirname(fin_path) == os.path.dirname(fout_path) + ): + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no changes]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + # Add hipify breadcrumb for C-style files to avoid re-hipification + if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")): + output_source = HIPIFY_C_BREADCRUMB + output_source + + do_write = True + if os.path.exists(fout_path): + with open(fout_path, encoding='utf-8') as fout_old: + do_write = fout_old.read() != output_source + if do_write: + try: + with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout: + fout.write(output_source) + hipify_result.hipified_path = fout_path + hipify_result.status = "[ok]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + except OSError as e: + print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}', + file=sys.stderr) + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no permissions]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + else: + hipify_result.hipified_path = fout_path + hipify_result.status = "[skipped, already hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + +def file_specific_replacement(filepath, search_string, replace_string, strict=False) -> None: + with openf(filepath, "r+") as f: + contents = f.read() + if strict: + contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents) + else: + contents = contents.replace(search_string, replace_string) + f.seek(0) + f.write(contents) + f.truncate() + + +def file_add_header(filepath, header) -> None: + with openf(filepath, "r+") as f: + contents = f.read() + if header[0] != "<" and header[-1] != ">": + header = f'"{header}"' + contents = (f'#include {header} \n') + contents + f.seek(0) + f.write(contents) + f.truncate() + + +def fix_static_global_kernels(in_txt): + """Static global kernels in HIP results in a compilation error.""" + in_txt = in_txt.replace(" __global__ static", "__global__") + return in_txt + + +RE_INCLUDE = re.compile(r"#include .*\n") + + +def extract_arguments(start, string): + """ + Return the list of arguments in the upcoming function parameter closure. + Example: + string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))' + arguments (output): [{'start': 1, 'end': 7}, {'start': 8, 'end': 16}, \ + {'start': 17, 'end': 19}, {'start': 20, 'end': 53}] + """ + + arguments = [] + closures = { + "<": 0, + "(": 0 + } + current_position = start + argument_start_pos = current_position + 1 + + # Search for final parenthesis + while current_position < len(string): + if string[current_position] == "(": + closures["("] += 1 + elif string[current_position] == ")": + closures["("] -= 1 + elif string[current_position] == "<": + closures["<"] += 1 + elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0: + closures["<"] -= 1 + + # Finished all arguments + if closures["("] == 0 and closures["<"] == 0: + # Add final argument + arguments.append({"start": argument_start_pos, "end": current_position}) + break + + # Finished current argument + if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",": + arguments.append({"start": argument_start_pos, "end": current_position}) + argument_start_pos = current_position + 1 + + current_position += 1 + + return arguments + + +def str2bool(v : str) -> bool: + """ArgumentParser doesn't support type=bool. Thus, this helper method will convert + from possible string types to True / False.""" + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def hipify( + project_directory: str, + show_detailed: bool = False, + extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), + header_extensions: Iterable = (".cuh", ".h", ".hpp"), + output_directory: str = "", + header_include_dirs: Iterable = (), + includes: Iterable = ('*',), + extra_files: Iterable = (), + out_of_place_only: bool = False, + ignores: Iterable = (), + show_progress: bool = True, + hip_clang_launch: bool = False, + is_pytorch_extension: bool = False, + hipify_extra_files_only: bool = False, + clean_ctx: GeneratedFileCleaner | None = None +) -> HipifyFinalResult: + if project_directory == "": + project_directory = os.getcwd() + + # Verify the project directory exists. + if not os.path.exists(project_directory): + print("The project folder specified does not exist.") + sys.exit(1) + + # If no output directory, provide a default one. + if not output_directory: + project_directory.rstrip("/") + output_directory = project_directory + "_amd" + + if project_directory != output_directory: + includes = [include.replace(project_directory, output_directory) for include in includes] + ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores] + + # Copy from project directory to output directory if not done already. + if not os.path.exists(output_directory): + shutil.copytree(project_directory, output_directory) + + includes = list(map(_to_unix_path, includes)) + ignores = list(map(_to_unix_path, ignores)) + + all_files = list(matched_files_iter(output_directory, includes=includes, + ignores=ignores, extensions=extensions, + out_of_place_only=out_of_place_only, + is_pytorch_extension=is_pytorch_extension)) + all_files_set = set(all_files) + # pyrefly: ignore [bad-assignment] + for f in extra_files: + if not os.path.isabs(f): + f = os.path.join(output_directory, f) + if f not in all_files_set: + all_files.append(f) + + # List all files in header_include_paths to ensure they are hipified + from pathlib import Path + for header_include_dir in header_include_dirs: + if os.path.isabs(header_include_dir): + header_include_dir_path = Path(header_include_dir) + else: + header_include_dir_path = Path(os.path.join(output_directory, header_include_dir)) + all_files.extend( + str(path) for path in header_include_dir_path.rglob('*') if path.is_file() + and _fnmatch(str(path), includes) + and (not _fnmatch(str(path), ignores)) + and match_extensions(path.name, header_extensions) + ) + + if clean_ctx is None: + clean_ctx = GeneratedFileCleaner(keep_intermediates=True) + + # Preprocessing statistics. + stats: dict[str, list] = {"unsupported_calls": [], "kernel_launches": []} + + for filepath in (all_files if not hipify_extra_files_only else extra_files): + preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs, + stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) + + # Show detailed summary + if show_detailed: + compute_stats(stats) + + return HIPIFY_FINAL_RESULT diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/version.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/version.py new file mode 100644 index 0000000000000000000000000000000000000000..1f356cc57bfa00a3b251402604c54702fb414c96 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/hipify/version.py @@ -0,0 +1 @@ +__version__ = '1.0.0' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b81f3a16f43034a42ee29ea24324406e5a673bf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..844a9835b2d0ac3bd2cd53213c9550064b8fd664 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/log_extract.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/log_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..9e018457802f4aafd05ba6a8d10ef1c4953b1047 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/jit/log_extract.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager +from typing import Any, cast +import random +import torch +import time +from torch.utils.benchmark import Timer + +def extract_ir(filename: str) -> list[str]: + BEGIN = "" + END = "" + pfx = None + graphs = [] + with open(filename) as f: + split_strs = f.read().split(BEGIN) + for i, split_str in enumerate(split_strs): + if i == 0: + continue + end_loc = split_str.find(END) + if end_loc == -1: + continue + s = split_str[:end_loc] + pfx = split_strs[i - 1].splitlines()[-1] + lines = [x[len(pfx):] for x in s.splitlines(keepends=True)] + graphs.append(''.join(lines)) + + return graphs + + +def make_tensor_from_type(inp_type: torch._C.TensorType): + size = inp_type.sizes() + stride = inp_type.strides() + device = inp_type.device() + dtype = inp_type.dtype() + if size is None: + raise AssertionError("make_tensor_from_type: 'size' is None (inp_type.sizes() returned None)") + if stride is None: + raise AssertionError("make_tensor_from_type: 'stride' is None (inp_type.strides() returned None)") + if device is None: + raise AssertionError("make_tensor_from_type: 'device' is None (inp_type.device() returned None)") + if dtype is None: + raise AssertionError("make_tensor_from_type: 'dtype' is None (inp_type.dtype() returned None)") + return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype) + +def load_graph_and_inputs(ir: str) -> tuple[Any, list[Any]]: + graph = torch._C.parse_ir(ir, parse_tensor_constants=True) + graph.makeMultiOutputIntoTuple() + inputs = [] + for inp in graph.inputs(): + if isinstance(inp.type(), torch._C.FloatType): + inputs.append(random.uniform(.1, 100)) + elif isinstance(inp.type(), torch._C.IntType): + inputs.append(random.randint(1, 100)) + elif isinstance(inp.type(), torch._C.TensorType): + tensorType = cast(torch._C.TensorType, inp.type()) + inputs.append(make_tensor_from_type(tensorType)) + elif isinstance(inp.type(), torch._C.BoolType): + inputs.append(random.randint(0, 1) == 1) + else: + raise NotImplementedError(f"A default value is not implemented for type {inp.type()}") + + func = torch._C._create_function_from_graph("forward", graph) + torch._C._jit_pass_erase_shape_information(func.graph) + return (func, inputs) + +def time_cuda(fn, inputs, test_runs): + t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs" : inputs}) + times = t.blocked_autorange() + return times.median * 1000 # time in ms + +def time_cpu(fn, inputs, test_runs): + s = time.perf_counter() + for _ in range(test_runs): + fn(*inputs) + e = time.perf_counter() + return (e - s) / test_runs * 1000 # time in ms + +def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float: + graph, _ = load_graph_and_inputs(ir) + for _ in range(warmup_runs): + graph(*inputs) + + is_cpu = None + for input in inputs: + if isinstance(input, torch.Tensor): + is_cpu = input.device.type == "cpu" + break + if is_cpu is None: + raise AssertionError("No tensor found in inputs") + + out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs) + return out + +@contextmanager +def no_fuser(*args, **kwargs): + old_optimize = torch._C._get_graph_executor_optimize(False) + try: + yield + finally: + torch._C._get_graph_executor_optimize(old_optimize) + +def run_baseline_no_fusion(ir, inputs) -> float: + with no_fuser(): + return run_test(ir, inputs) + + +def run_nnc(ir, inputs, dynamic) -> float: + try: + strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)] + old_strat = torch.jit.set_fusion_strategy(strat) + with torch.jit.fuser("fuser1"): + return run_test(ir, inputs) + finally: + torch.jit.set_fusion_strategy(old_strat) + +def run_nvfuser(ir, inputs) -> float: + with torch.jit.fuser("fuser2"): + return run_test(ir, inputs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16d1ab1c6dd1a2d422cae74eaa5b5888dd2fa175 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" +model_dump: a one-stop shop for TorchScript model inspection. + +The goal of this tool is to provide a simple way to extract lots of +useful information from a TorchScript model and make it easy for humans +to consume. It (mostly) replaces zipinfo, common uses of show_pickle, +and various ad-hoc analysis notebooks. + +The tool extracts information from the model and serializes it as JSON. +That JSON can then be rendered by an HTML+JS page, either by +loading the JSON over HTTP or producing a fully self-contained page +with all of the code and data burned-in. +""" + +# Maintainer notes follow. +""" +The implementation strategy has tension between 3 goals: +- Small file size. +- Fully self-contained. +- Easy, modern JS environment. +Using Preact and HTM achieves 1 and 2 with a decent result for 3. +However, the models I tested with result in ~1MB JSON output, +so even using something heavier like full React might be tolerable +if the build process can be worked out. + +One principle I have followed that I think is very beneficial +is to keep the JSON data as close as possible to the model +and do most of the rendering logic on the client. +This makes for easier development (just refresh, usually), +allows for more laziness and dynamism, and lets us add more +views of the same data without bloating the HTML file. + +Currently, this code doesn't actually load the model or even +depend on any part of PyTorch. I don't know if that's an important +feature to maintain, but it's probably worth preserving the ability +to run at least basic analysis on models that cannot be loaded. + +I think the easiest way to develop this code is to cd into model_dump and +run "python -m http.server", then load http://localhost:8000/skeleton.html +in the browser. In another terminal, run +"python -m torch.utils.model_dump --style=json FILE > \ + torch/utils/model_dump/model_info.json" +every time you update the Python code or model. +When you update JS, just refresh. + +Possible improvements: + - Fix various TODO comments in this file and the JS. + - Make the HTML much less janky, especially the auxiliary data panel. + - Make the auxiliary data panel start small, expand when + data is available, and have a button to clear/contract. + - Clean up the JS. There's a lot of copypasta because + I don't really know how to use Preact. + - Make the HTML render and work nicely inside a Jupyter notebook. + - Add the ability for JS to choose the URL to load the JSON based + on the page URL (query or hash). That way we could publish the + inlined skeleton once and have it load various JSON blobs. + - Add a button to expand all expandable sections so ctrl-F works well. + - Add hyperlinking from data to code, and code to code. + - Add hyperlinking from debug info to Diffusion. + - Make small tensor contents available. + - Do something nice for quantized models + (they probably don't work at all right now). +""" + +import argparse +import io +import itertools +import json +import os +import pickle +import pprint +import re +import sys +import urllib.parse +import zipfile +from pathlib import Path +import warnings + +import torch.utils.show_pickle + + +DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024 + +__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton', + 'burn_in_info', 'get_info_and_burn_skeleton'] + +def get_storage_info(storage): + if not isinstance(storage, torch.utils.show_pickle.FakeObject): + raise AssertionError(f"storage is not FakeObject: {type(storage)}") + if storage.module != "pers": + raise AssertionError(f"storage.module is not 'pers': {storage.module!r}") + if storage.name != "obj": + raise AssertionError(f"storage.name is not 'obj': {storage.name!r}") + if storage.state is not None: + raise AssertionError(f"storage.state is not None: {storage.state!r}") + if not isinstance(storage.args, tuple): + raise AssertionError(f"storage.args is not a tuple: {type(storage.args)}") + if len(storage.args) != 1: + raise AssertionError(f"len(storage.args) is not 1: {len(storage.args)}") + sa = storage.args[0] + if not isinstance(sa, tuple): + raise AssertionError(f"sa is not a tuple: {type(sa)}") + if len(sa) != 5: + raise AssertionError(f"len(sa) is not 5: {len(sa)}") + if sa[0] != "storage": + raise AssertionError(f"sa[0] is not 'storage': {sa[0]!r}") + if not isinstance(sa[1], torch.utils.show_pickle.FakeClass): + raise AssertionError(f"sa[1] is not FakeClass: {type(sa[1])}") + if sa[1].module != "torch": + raise AssertionError(f"sa[1].module is not 'torch': {sa[1].module!r}") + if not sa[1].name.endswith("Storage"): + raise AssertionError(f"sa[1].name does not end with 'Storage': {sa[1].name!r}") + storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:]) + return storage_info + + +def hierarchical_pickle(data): + if isinstance(data, (bool, int, float, str, type(None))): + return data + if isinstance(data, list): + return [hierarchical_pickle(d) for d in data] + if isinstance(data, tuple): + return { + "__tuple_values__": hierarchical_pickle(list(data)), + } + if isinstance(data, dict): + return { + "__is_dict__": True, + "keys": hierarchical_pickle(list(data.keys())), + "values": hierarchical_pickle(list(data.values())), + } + if isinstance(data, torch.utils.show_pickle.FakeObject): + typename = f"{data.module}.{data.name}" + if ( + typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.')) + ): + if data.args != (): + raise AssertionError("data.args is not ()") + return { + "__module_type__": typename, + "state": hierarchical_pickle(data.state), + } + if typename == "torch._utils._rebuild_tensor_v2": + if data.state is not None: + raise AssertionError("data.state is not None") + storage, offset, size, stride, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]} + if typename == "torch._utils._rebuild_qtensor": + if data.state is not None: + raise AssertionError("data.state is not None") + storage, offset, size, stride, quantizer, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + if not isinstance(quantizer, tuple): + raise AssertionError("quantizer is not a tuple") + if not isinstance(quantizer[0], torch.utils.show_pickle.FakeClass): + raise AssertionError("quantizer[0] is not a FakeClass") + if quantizer[0].module != "torch": + raise AssertionError("quantizer[0].module is not torch") + if quantizer[0].name == "per_tensor_affine": + if len(quantizer) != 3: + raise AssertionError("len(quantizer) is not 3") + if not isinstance(quantizer[1], float): + raise AssertionError("quantizer[1] is not a float") + if not isinstance(quantizer[2], int): + raise AssertionError("quantizer[2] is not an int") + quantizer_extra = list(quantizer[1:3]) + else: + quantizer_extra = [] + quantizer_json = [quantizer[0].name] + quantizer_extra + return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]} + if typename == "torch.jit._pickle.restore_type_tag": + if data.state is not None: + raise AssertionError("data.state is not None") + obj, typ = data.args + if not isinstance(typ, str): + raise AssertionError("typ is not a string") + return hierarchical_pickle(obj) + if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename): + if data.state is not None: + raise AssertionError("data.state is not None") + ls, = data.args + if not isinstance(ls, list): + raise AssertionError("ls is not a list") + return hierarchical_pickle(ls) + if typename == "torch.device": + if data.state is not None: + raise AssertionError("data.state is not None") + name, = data.args + if not isinstance(name, str): + raise AssertionError("name is not a string") + # Just forget that it was a device and return the name. + return name + if typename == "builtin.UnicodeDecodeError": + if data.state is not None: + raise AssertionError("data.state is not None") + msg, = data.args + if not isinstance(msg, str): + raise AssertionError("msg is not a string") + # Hack: Pretend this is a module so we don't need custom serialization. + # Hack: Wrap the message in a tuple so it looks like a nice state object. + # TODO: Undo at least that second hack. We should support string states. + return { + "__module_type__": typename, + "state": hierarchical_pickle((msg,)), + } + raise Exception(f"Can't prepare fake object of type for JS: {typename}") # noqa: TRY002 + raise Exception(f"Can't prepare data of type for JS: {type(data)}") # noqa: TRY002 + + +def get_model_info( + path_or_file, + title=None, + extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT): + """Get JSON-friendly information about a model. + + The result is suitable for being saved as model_info.json, + or passed to burn_in_info. + """ + + if isinstance(path_or_file, os.PathLike): + default_title = os.fspath(path_or_file) + file_size = path_or_file.stat().st_size # type: ignore[attr-defined] + elif isinstance(path_or_file, str): + default_title = path_or_file + file_size = Path(path_or_file).stat().st_size + else: + default_title = "buffer" + path_or_file.seek(0, io.SEEK_END) + file_size = path_or_file.tell() + path_or_file.seek(0) + + title = title or default_title + + with zipfile.ZipFile(path_or_file) as zf: + path_prefix = None + zip_files = [] + # pyrefly: ignore [bad-assignment] + for zi in zf.infolist(): + prefix = re.sub("/.*", "", zi.filename) + if path_prefix is None: + path_prefix = prefix + elif prefix != path_prefix: + raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002 + zip_files.append( + { + "filename": zi.filename, + "compression": zi.compress_type, + "compressed_size": zi.compress_size, + "file_size": zi.file_size, + } + ) + if path_prefix is None: + raise AssertionError("path_prefix is None") + version = zf.read(path_prefix + "/version").decode("utf-8").strip() + + def get_pickle(name): + if path_prefix is None: + raise AssertionError("path_prefix is None") + with zf.open(path_prefix + f"/{name}.pkl") as handle: + raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + return hierarchical_pickle(raw) + + model_data = get_pickle("data") + constants = get_pickle("constants") + + # Intern strings that are likely to be reused. + # Pickle automatically detects shared structure, + # so reused strings are stored efficiently. + # However, JSON has no way of representing this, + # so we have to do it manually. + interned_strings : dict[str, int] = {} + + def intern(s): + if s not in interned_strings: + interned_strings[s] = len(interned_strings) + return interned_strings[s] + + code_files = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".py"): + continue + with zf.open(zi) as handle: + raw_code = handle.read() + with zf.open(zi.filename + ".debug_pkl") as handle: + raw_debug = handle.read() + + # Parse debug info and add begin/end markers if not present + # to ensure that we cover the entire source code. + debug_info_t = pickle.loads(raw_debug) + text_table = None + + if (len(debug_info_t) == 3 and + isinstance(debug_info_t[0], str) and + debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'): + _, text_table, content = debug_info_t + + def parse_new_format(line): + # (0, (('', '', 0), 0, 0)) + num, ((text_indexes, fname_idx, offset), start, end), tag = line + text = ''.join(text_table[x] for x in text_indexes) # type: ignore[index] + fname = text_table[fname_idx] # type: ignore[index] + return num, ((text, fname, offset), start, end), tag + + debug_info_t = map(parse_new_format, content) + + debug_info = list(debug_info_t) + if not debug_info: + debug_info.append((0, (('', '', 0), 0, 0))) + if debug_info[-1][0] != len(raw_code): + debug_info.append((len(raw_code), (('', '', 0), 0, 0))) + + code_parts = [] + for di, di_next in itertools.pairwise(debug_info): + start, source_range, *_ = di + end = di_next[0] + if end <= start: + raise AssertionError("end is not greater than start") + source, s_start, s_end = source_range + s_text, s_file, s_line = source + # TODO: Handle this case better. TorchScript ranges are in bytes, + # but JS doesn't really handle byte strings. + # if bytes and chars are not equivalent for this string, + # zero out the ranges so we don't highlight the wrong thing. + if len(s_text) != len(s_text.encode("utf-8")): + s_start = 0 + s_end = 0 + text = raw_code[start:end] + code_parts.append([text.decode("utf-8"), intern(s_file), s_line, intern(s_text), s_start, s_end]) + code_files[zi.filename] = code_parts + + extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json") + extra_files_jsons = {} + for zi in zf.infolist(): + if not extra_files_json_pattern.fullmatch(zi.filename): + continue + if zi.file_size > extra_file_size_limit: + continue + with zf.open(zi) as handle: + try: + json_content = json.load(handle) + extra_files_jsons[zi.filename] = json_content + except json.JSONDecodeError: + extra_files_jsons[zi.filename] = "INVALID JSON" + + always_render_pickles = { + "bytecode.pkl", + } + extra_pickles = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".pkl"): + continue + with zf.open(zi) as handle: + # TODO: handle errors here and just ignore the file? + # NOTE: For a lot of these files (like bytecode), + # we could get away with just unpickling, but this should be safer. + obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + buf = io.StringIO() + pprint.pprint(obj, buf) + contents = buf.getvalue() + # Checked the rendered length instead of the file size + # because pickles with shared structure can explode in size during rendering. + if os.path.basename(zi.filename) not in always_render_pickles and \ + len(contents) > extra_file_size_limit: + continue + extra_pickles[zi.filename] = contents + + return { + "model": { + "title": title, + "file_size": file_size, + "version": version, + "zip_files": zip_files, + "interned_strings": list(interned_strings), + "code_files": code_files, + "model_data": model_data, + "constants": constants, + "extra_files_jsons": extra_files_jsons, + "extra_pickles": extra_pickles, + } + } + + +def get_inline_skeleton(): + """Get a fully-inlined skeleton of the frontend. + + The returned HTML page has no external network dependencies for code. + It can load model_info.json over HTTP, or be passed to burn_in_info. + """ + + import importlib.resources + + # pyrefly: ignore [bad-argument-type] + skeleton = importlib.resources.read_text(__package__, "skeleton.html") + # pyrefly: ignore [bad-argument-type] + js_code = importlib.resources.read_text(__package__, "code.js") + for js_module in ["preact", "htm"]: + # pyrefly: ignore [bad-argument-type] + js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") + js_url = "data:application/javascript," + urllib.parse.quote(js_lib) + js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url) + skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code) + return skeleton + + +def burn_in_info(skeleton, info): + """Burn model info into the HTML skeleton. + + The result will render the hard-coded model info and + have no external network dependencies for code or data. + """ + + # Note that Python's json serializer does not escape slashes in strings. + # Since we're inlining this JSON directly into a script tag, a string + # containing "" would end the script prematurely and + # mess up our page. Unconditionally escape fixes that. + return skeleton.replace( + "BURNED_IN_MODEL_INFO = null", + "BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/")) + + +def get_info_and_burn_skeleton(path_or_bytesio, **kwargs): + model_info = get_model_info(path_or_bytesio, **kwargs) + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, model_info) + return page + + +def main(argv, *, stdout=None) -> None: + warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.", stacklevel=2) + parser = argparse.ArgumentParser() + parser.add_argument("--style", choices=["json", "html"]) + parser.add_argument("--title") + parser.add_argument("model") + args = parser.parse_args(argv[1:]) + + info = get_model_info(args.model, title=args.title) + + output = stdout or sys.stdout + + if args.style == "json": + output.write(json.dumps(info, sort_keys=True) + "\n") + elif args.style == "html": + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, info) + output.write(page) + else: + raise Exception("Invalid style") # noqa: TRY002 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4bdac389bb1f270d74efb6c876258d46077110 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +import sys +from . import main + +sys.exit(main(sys.argv)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d865e95f79d6c562d41caf50474ba9273ec7a359 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6660d194768b4d057e48a3a77c6871ce9c82c6a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/code.js b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/code.js new file mode 100644 index 0000000000000000000000000000000000000000..173ddfb639d847159ee4fdf46691404bf1bbb7a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/code.js @@ -0,0 +1,689 @@ +import { h, Component, render } from 'https://unpkg.com/preact?module'; +import htm from 'https://unpkg.com/htm?module'; + +const html = htm.bind(h); + +const BURNED_IN_MODEL_INFO = null; + +// https://stackoverflow.com/a/20732091 +function humanFileSize(size) { + if (size == 0) { return "0 B"; } + var i = Math.floor( Math.log(size) / Math.log(1024) ); + return (size / Math.pow(1024, i)).toFixed(2) * 1 + ' ' + ['B', 'kB', 'MB', 'GB', 'TB'][i]; +} + +function caret(down) { + return down ? "\u25BE" : "\u25B8"; +} + +class Blamer { + constructor() { + this.blame_on_click = false; + this.aux_content_pane = null; + } + + setAuxContentPane(pane) { + this.aux_content_pane = pane; + } + + readyBlame() { + this.blame_on_click = true; + } + + maybeBlame(arg) { + if (!this.blame_on_click) { + return; + } + this.blame_on_click = false; + if (!this.aux_content_pane) { + return; + } + this.aux_content_pane.doBlame(arg); + } +} + +let blame = new Blamer(); + +class Hider extends Component { + constructor() { + super(); + this.state = { shown: null }; + } + + componentDidMount() { + this.setState({ shown: this.props.shown === "true" }); + } + + render({name, children}, {shown}) { + let my_caret = html` this.click()} >${caret(shown)}`; + return html`
+

${my_caret} ${name}

+
${shown ? this.props.children : []}
`; + } + + click() { + this.setState({shown: !this.state.shown}); + } +} + +function ModelSizeSection({model: {file_size, zip_files}}) { + let store_size = 0; + let compr_size = 0; + for (const zi of zip_files) { + if (zi.compression === 0) { + // TODO: Maybe check that compressed_size === file_size. + store_size += zi.compressed_size; + } else { + compr_size += zi.compressed_size; + } + } + let zip_overhead = file_size - store_size - compr_size; + // TODO: Better formatting. Right-align this. + return html` + <${Hider} name="Model Size" shown=true> +
.
+      Model size: ${file_size} (${humanFileSize(file_size)})
+      Stored files: ${store_size} (${humanFileSize(store_size)})
+      Compressed files: ${compr_size} (${humanFileSize(compr_size)})
+      Zip overhead: ${zip_overhead} (${humanFileSize(zip_overhead)})
+    
`; +} + +function StructuredDataSection({name, data, shown}) { + return html` + <${Hider} name=${name} shown=${shown}> +
+ <${StructuredData} data=${data} indent="" prefix=""/> +
`; +} + +class StructuredData extends Component { + constructor() { + super(); + this.state = { shown: false }; + + this.INLINE_TYPES = new Set(["boolean", "number", "string"]) + this.IGNORED_STATE_KEYS = new Set(["training", "_is_full_backward_hook"]) + } + + click() { + this.setState({shown: !this.state.shown}); + } + + expando(data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + return false; + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__tuple_values__) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__is_dict__) { + // TODO: Maybe show simple (empty?) dicts on one line. + return true; + } + if (data.__module_type__) { + return true; + } + if (data.__tensor_v2__) { + return false; + } + if (data.__qtensor__) { + return false; + } + throw new Error("Can't handle data type.", data); + } + + renderHeadline(data) { + if (data === null) { + return "None"; + } + if (typeof(data) == "boolean") { + const sd = String(data); + return sd.charAt(0).toUpperCase() + sd.slice(1); + } + if (typeof(data) == "number") { + return JSON.stringify(data); + } + if (typeof(data) == "string") { + return JSON.stringify(data); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + return "list(["; + } + if (data.__tuple_values__) { + return "tuple(("; + } + if (data.__is_dict__) { + return "dict({"; + } + if (data.__module_type__) { + return data.__module_type__ + "()"; + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return this.renderTensor( + "tensor", dtype, key, device, numel, offset, size, stride, grad, []); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + let extra_parts = []; + if (quantizer[0] == "per_tensor_affine") { + extra_parts.push(`scale=${quantizer[1]}`); + extra_parts.push(`zero_point=${quantizer[2]}`); + } else { + extra_parts.push(`quantizer=${quantizer[0]}`); + } + return this.renderTensor( + "qtensor", dtype, key, device, numel, offset, size, stride, grad, extra_parts); + } + throw new Error("Can't handle data type.", data); + } + + renderTensor( + prefix, + dtype, + storage_key, + device, + storage_numel, + offset, + size, + stride, + grad, + extra_parts) { + let parts = [ + "(" + size.join(",") + ")", + dtype, + ]; + parts.push(...extra_parts); + if (device != "cpu") { + parts.push(device); + } + if (grad) { + parts.push("grad"); + } + // TODO: Check stride and indicate if the tensor is channels-last or non-contiguous + // TODO: Check size, stride, offset, and numel and indicate if + // the tensor doesn't use all data in storage. + // TODO: Maybe show key? + void(offset); + void(stride); + void(storage_key); + void(storage_numel); + return prefix + "(" + parts.join(", ") + ")"; + } + + renderBody(indent, data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + throw "Should not reach here." + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.length; idx++) { + // Does it make sense to put explicit index numbers here? + parts.push(html`
<${StructuredData} prefix=${idx + ": "} indent=${new_indent} data=${data[idx]} />`); + } + return parts; + } + if (data.__tuple_values__) { + // Handled the same as lists. + return this.renderBody(indent, data.__tuple_values__); + } + if (data.__is_dict__) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.keys.length; idx++) { + if (typeof(data.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else { + parts.push(html`
<${StructuredData} prefix=${data.keys[idx] + ": "} indent=${new_indent} data=${data.values[idx]} />`); + } + } + return parts; + } + if (data.__module_type__) { + const mstate = data.state; + if (mstate === null || typeof(mstate) != "object") { + throw new Error("Bad module state"); + } + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + if (mstate.__is_dict__) { + // TODO: Less copy/paste between this and normal dicts. + for (let idx = 0; idx < mstate.keys.length; idx++) { + if (typeof(mstate.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else if (this.IGNORED_STATE_KEYS.has(mstate.keys[idx])) { + // Do nothing. + } else { + parts.push(html`
<${StructuredData} prefix=${mstate.keys[idx] + ": "} indent=${new_indent} data=${mstate.values[idx]} />`); + } + } + } else if (mstate.__tuple_values__) { + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else if (mstate.__module_type__) { + // We normally wouldn't have the state of a module be another module, + // but we use "modules" to encode special values (like Unicode decode + // errors) that might be valid states. Just go with it. + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else { + throw new Error("Bad module state"); + } + return parts; + } + if (data.__tensor_v2__) { + throw "Should not reach here." + } + if (data.__qtensor__) { + throw "Should not reach here." + } + throw new Error("Can't handle data type.", data); + } + + render({data, indent, prefix}, {shown}) { + const exp = this.expando(data) ? html` this.click()} >${caret(shown)} ` : ""; + const headline = this.renderHeadline(data); + const body = shown ? this.renderBody(indent, data) : ""; + return html`${indent}${exp}${prefix}${headline}${body}`; + } +} + +function ZipContentsSection({model: {zip_files}}) { + // TODO: Add human-readable sizes? + // TODO: Add sorting options? + // TODO: Add hierarchical collapsible tree? + return html` + <${Hider} name="Zip Contents" shown=false> + + + + + + + + + + + ${zip_files.map(zf => html` + + + + + `)} + +
ModeSizeCompressedName
${{0: "store", 8: "deflate"}[zf.compression] || zf.compression}${zf.file_size}${zf.compressed_size}${zf.filename}
`; +} + +function CodeSection({model: {code_files}}) { + return html` + <${Hider} name="Code" shown=false> +
+ ${Object.entries(code_files).map(([fn, code]) => html`<${OneCodeSection} + filename=${fn} code=${code} />`)} +
`; +} + +class OneCodeSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, code}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${code.map(c => this.renderBlock(c))}
+ `; + } + + renderBlock([text, ist_file, line, ist_s_text, s_start, s_end]) { + return html` blame.maybeBlame({ist_file, line, ist_s_text, s_start, s_end})} + >${text}`; + } +} + +function ExtraJsonSection({files}) { + return html` + <${Hider} name="Extra files (JSON)" shown=false> +
+

Use "Log Raw Model Info" for hierarchical view in browser console.

+ ${Object.entries(files).map(([fn, json]) => html`<${OneJsonSection} + filename=${fn} json=${json} />`)} +
`; +} + +class OneJsonSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, json}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${JSON.stringify(json, null, 2)}
+ `; + } +} + +function ExtraPicklesSection({files}) { + return html` + <${Hider} name="Extra Pickles" shown=false> +
+ ${Object.entries(files).map(([fn, content]) => html`<${OnePickleSection} + filename=${fn} content=${content} />`)} +
`; +} + +class OnePickleSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, content}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${content}
+ `; + } +} + +function assertStorageAreEqual(key, lhs, rhs) { + if (lhs.length !== rhs.length || + !lhs.every((val, idx) => val === rhs[idx])) { + throw new Error("Storage mismatch for key '" + key + "'"); + } +} + +function computeTensorMemory(numel, dtype) { + const sizes = { + "Byte": 1, + "Char": 1, + "Short": 2, + "Int": 4, + "Long": 8, + "Half": 2, + "Float": 4, + "Double": 8, + "ComplexHalf": 4, + "ComplexFloat": 8, + "ComplexDouble": 16, + "Bool": 1, + "QInt8": 1, + "QUInt8": 1, + "QInt32": 4, + "BFloat16": 2, + }; + let dtsize = sizes[dtype]; + if (!dtsize) { + throw new Error("Unrecognized dtype: " + dtype); + } + return numel * dtsize; +} + +// TODO: Maybe track by dtype as well. +// TODO: Maybe distinguish between visible size and storage size. +function getTensorStorages(data) { + if (data === null) { + return new Map(); + } + if (typeof(data) == "boolean") { + return new Map(); + } + if (typeof(data) == "number") { + return new Map(); + } + if (typeof(data) == "string") { + return new Map(); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let result = new Map(); + for (const item of data) { + const tensors = getTensorStorages(item); + for (const [key, storage] of tensors.entries()) { + if (!result.has(key)) { + result.set(key, storage); + } else { + const old_storage = result.get(key); + assertStorageAreEqual(key, old_storage, storage); + } + } + } + return result; + } + if (data.__tuple_values__) { + return getTensorStorages(data.__tuple_values__); + } + if (data.__is_dict__) { + return getTensorStorages(data.values); + } + if (data.__module_type__) { + return getTensorStorages(data.state); + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + throw new Error("Can't handle data type.", data); +} + +function getTensorMemoryByDevice(pickles) { + let all_tensors = []; + for (const [name, pickle] of pickles) { + const tensors = getTensorStorages(pickle); + all_tensors.push(...tensors.values()); + } + let result = {}; + for (const storage of all_tensors.values()) { + const [dtype, key, device, numel] = storage; + const size = computeTensorMemory(numel, dtype); + result[device] = (result[device] || 0) + size; + } + return result; +} + +// Make this a separate component so it is rendered lazily. +class OpenTensorMemorySection extends Component { + render({model: {model_data, constants}}) { + let sizes = getTensorMemoryByDevice(new Map([ + ["data", model_data], + ["constants", constants], + ])); + return html` + + + + + + + + + + ${Object.entries(sizes).map(([dev, size]) => html` + + + + `)} + +
DeviceBytesHuman
${dev}${size}${humanFileSize(size)}
`; + } +} + +function TensorMemorySection({model}) { + return html` + <${Hider} name="Tensor Memory" shown=false> + <${OpenTensorMemorySection} model=${model} />`; +} + +class AuxContentPane extends Component { + constructor() { + super(); + this.state = { + blame_info: null, + }; + } + + doBlame(arg) { + this.setState({...this.state, blame_info: arg}); + } + + render({model: {interned_strings}}, {blame_info}) { + let blame_content = ""; + if (blame_info) { + const {ist_file, line, ist_s_text, s_start, s_end} = blame_info; + let s_text = interned_strings[ist_s_text]; + if (s_start != 0 || s_end != s_text.length) { + let prefix = s_text.slice(0, s_start); + let main = s_text.slice(s_start, s_end); + let suffix = s_text.slice(s_end); + s_text = html`${prefix}${main}${suffix}`; + } + blame_content = html` +

${interned_strings[ist_file]}:${line}

+
${s_start}:${s_end}
+
${s_text}

+ `; + } + return html` + +
+ ${blame_content} + `; + } +} + +class App extends Component { + constructor() { + super(); + this.state = { + err: false, + model: null, + }; + } + + componentDidMount() { + const app = this; + if (BURNED_IN_MODEL_INFO !== null) { + app.setState({model: BURNED_IN_MODEL_INFO}); + } else { + fetch("./model_info.json").then(function(response) { + if (!response.ok) { + throw new Error("Response not ok."); + } + return response.json(); + }).then(function(body) { + app.setState({model: body}); + }).catch(function(error) { + console.log("Top-level error: ", error); + }); + } + } + + componentDidCatch(error) { + void(error); + this.setState({...this.state, err: true}); + } + + render(_, {err}) { + if (this.state.model === null) { + return html`

Loading...

`; + } + + const model = this.state.model.model; + + let error_msg = ""; + if (err) { + error_msg = html`

An error occurred. Check console

`; + } + + return html` + ${error_msg} +
+

TorchScript Model (version ${model.version}): ${model.title}

+ + <${ModelSizeSection} model=${model}/> + <${StructuredDataSection} name="Model Data" data=${model.model_data} shown=true/> + <${StructuredDataSection} name="Constants" data=${model.constants} shown=false/> + <${ZipContentsSection} model=${model}/> + <${CodeSection} model=${model}/> + <${ExtraJsonSection} files=${model.extra_files_jsons}/> + <${ExtraPicklesSection} files=${model.extra_pickles}/> + <${TensorMemorySection} model=${model}/> +
+
+ <${AuxContentPane} + err=${this.state.error} + model=${model} + ref=${(p) => blame.setAuxContentPane(p)}/> +
+ `; + } +} + +render(h(App), document.body); diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs new file mode 100644 index 0000000000000000000000000000000000000000..06f25a13d8021ff4f43de442bbf0279f24735d6c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs @@ -0,0 +1,2 @@ +// HTM, Apache License +var n=function(t,s,r,e){var u;s[0]=0;for(var h=1;h=5&&((e||!n&&5===r)&&(h.push(r,0,e,s),r=6),n&&(h.push(r,n,0,s),r=6)),e=""},a=0;a"===t?(r=1,e=""):e=t+e[0]:u?t===u?u="":e+=t:'"'===t||"'"===t?u=t:">"===t?(p(),r=1):r&&("="===t?(r=5,s=e,e=""):"/"===t&&(r<5||">"===n[a][l+1])?(p(),3===r&&(h=h[0]),r=h,(h=h[0]).push(2,0,r),r=0):" "===t||"\t"===t||"\n"===t||"\r"===t?(p(),r=2):e+=t),3===r&&"!--"===e&&(r=4,h=h[0])}return p(),h}(s)),r),arguments,[])).length>1?r:r[0]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs new file mode 100644 index 0000000000000000000000000000000000000000..8c85bd948c6772ca8d40fc8d6fab6a220d55a1ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs @@ -0,0 +1,2 @@ +// Preact, MIT License +var n,l,u,i,t,o,r={},f=[],e=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i;function c(e,n){for(var t in n)e[t]=n[t];return e}function s(e){var n=e.parentNode;n&&n.removeChild(e)}function a(e,n,t){var _,l,o,r=arguments,i={};for(o in n)"key"==o?_=n[o]:"ref"==o?l=n[o]:i[o]=n[o];if(arguments.length>3)for(t=[t],o=3;o0?v(m.type,m.props,m.key,null,m.__v):m)){if(m.__=t,m.__b=t.__b+1,null===(h=P[p])||h&&m.key==h.key&&m.type===h.type)P[p]=void 0;else for(a=0;a3)for(t=[t],o=3;o + + + TorchScript Model + + + + + + + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d63bc18b69b138a026622de599aed656cc868c8e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__init__.py @@ -0,0 +1 @@ +from . import config diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614dc1e70529ab6bfe8f177745b8b24cdf3acb1e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b709eb69b8fed8115da57de129b132cc603e298d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e6729c68583f7206d07df7bfa2666007a6bd67 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/serialization/config.py @@ -0,0 +1,25 @@ +import sys +from typing import Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from torch.serialization import LoadEndianness as _LoadEndianess + +from torch.utils._config_module import install_config_module as _install_config_module + + +class load: + mmap: bool = False + endianness: _Optional["_LoadEndianess"] = None + # MAP_PRIVATE = 2 + mmap_flags: int | None = None if sys.platform == "win32" else 2 + calculate_storage_offsets: bool = False + + +class save: + compute_crc32: bool = True + use_pinned_memory_for_d2h: bool = False + storage_alignment: int = 64 + + +_install_config_module(sys.modules[__name__]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2ac5edd05e16ef51e75f2ca68864b65da5d58 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py @@ -0,0 +1,19 @@ +import tensorboard +from torch._vendor.packaging.version import Version + +if not hasattr(tensorboard, "__version__") or Version( + tensorboard.__version__ +) < Version("1.15"): + raise ImportError("TensorBoard logging requires TensorBoard version 1.15 or above") + +del Version +del tensorboard + +from .writer import FileWriter, SummaryWriter +from tensorboard.summary.writer.record_writer import RecordWriter + +__all__ = [ + "FileWriter", + "RecordWriter", + "SummaryWriter", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..769d64456d609340f523146e0da822f33eb2a642 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcd585b570d4898ba8d38530142a44d5a10c9650 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ca12c2f0568f3808ce0da1aa07ef29a31ca1b8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55633670908a87a60e4244c3d322ea5dcc67d7b4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1652a69d2d7225f7f884c6c3d1d7e3b004c62c39 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..501d488d9f982dce55f49e030fdd188fd78d3ed7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7121bc9b944072cae85e489d5da3ad873dadbedc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2152caa60d9f0de9e566371f6fa1b5e6c0b3ecd3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..726603d298b3dbe7ebfc50fb8479b39f44f5a061 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e8910580de16d9e7cf90f10d6327556e9a37a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py @@ -0,0 +1,37 @@ +"""This module converts objects into numpy array.""" + +import numpy as np + +import torch + + +def make_np(x: torch.Tensor) -> np.ndarray: + """ + Convert an object into numpy array. + + Args: + x: An instance of torch tensor + + Returns: + numpy.array: Numpy array + """ + if isinstance(x, np.ndarray): + return x + if np.isscalar(x): + return np.array([x]) + if isinstance(x, torch.Tensor): + if x.device.type == "meta": + return np.random.randn(1) + return _prepare_pytorch(x) + raise NotImplementedError( + f"Got {type(x)}, but numpy array or torch tensor are expected." + ) + + +def _prepare_pytorch(x: torch.Tensor) -> np.ndarray: + if x.dtype == torch.bfloat16: + x = x.to(torch.float16) + # pyrefly: ignore [bad-assignment] + x = x.detach().cpu().numpy() + # pyrefly: ignore [bad-return] + return x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..73413e219d0efbabe7d66747bd108ee5e8be4319 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +import math +import numpy as np +from ._convert_np import make_np +from ._utils import make_grid +from tensorboard.compat import tf +from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo + + +_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join") + + +def _gfile_join(a, b): + # The join API is different between tensorboard's TF stub and TF: + # https://github.com/tensorflow/tensorboard/issues/6080 + # We need to try both because `tf` may point to either the stub or the real TF. + if _HAS_GFILE_JOIN: + return tf.io.gfile.join(a, b) + else: + fs = tf.io.gfile.get_filesystem(a) + return fs.join(a, b) + + +def make_tsv(metadata, save_path, metadata_header=None) -> None: + if not metadata_header: + metadata = [str(x) for x in metadata] + else: + if len(metadata_header) != len( + metadata[0] + ): + raise AssertionError("len of header must be equal to the number of columns in metadata") + metadata = ["\t".join(str(e) for e in l) for l in [metadata_header] + metadata] + + metadata_bytes = tf.compat.as_bytes("\n".join(metadata) + "\n") + with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f: + f.write(metadata_bytes) + + +# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared +def make_sprite(label_img, save_path) -> None: + from PIL import Image + from io import BytesIO + + # this ensures the sprite image has correct dimension as described in + # https://www.tensorflow.org/get_started/embedding_viz + nrow = math.ceil((label_img.size(0)) ** 0.5) + arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow) + + # augment images so that #images equals nrow*nrow + arranged_augment_square_HWC = np.zeros( + (arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3) + ) + arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc + arranged_augment_square_HWC[: arranged_img_HWC.shape[0], :, :] = arranged_img_HWC + im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255))) + + with BytesIO() as buf: + im.save(buf, format="PNG") + im_bytes = buf.getvalue() + + with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f: + f.write(im_bytes) + + +def get_embedding_info(metadata, label_img, subdir, global_step, tag): + info = EmbeddingInfo() + info.tensor_name = f"{tag}:{str(global_step).zfill(5)}" + info.tensor_path = _gfile_join(subdir, "tensors.tsv") + if metadata is not None: + info.metadata_path = _gfile_join(subdir, "metadata.tsv") + if label_img is not None: + info.sprite.image_path = _gfile_join(subdir, "sprite.png") + info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)]) + return info + + +def write_pbtxt(save_path, contents) -> None: + config_path = _gfile_join(save_path, "projector_config.pbtxt") + with tf.io.gfile.GFile(config_path, "wb") as f: + f.write(tf.compat.as_bytes(contents)) + + +def make_mat(matlist, save_path) -> None: + with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f: + for x in matlist: + x = [str(i.item()) for i in x] + f.write(tf.compat.as_bytes("\t".join(x) + "\n")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..abadb7c9fdb421eb328f031a81aab5e231d4ca40 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from tensorboard.compat.proto.graph_pb2 import GraphDef +from tensorboard.compat.proto.node_def_pb2 import NodeDef +from tensorboard.compat.proto.versions_pb2 import VersionDef +from tensorboard.compat.proto.attr_value_pb2 import AttrValue +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + + +def load_onnx_graph(fname): + import onnx + + m = onnx.load(fname) # type: ignore[attr-defined] + g = m.graph + return parse(g) + + +def parse(graph): + nodes = [] + import itertools + + nodes_proto = list(itertools.chain(graph.input, graph.output)) + + for node in nodes_proto: + print(node.name) + shapeproto = TensorShapeProto( + dim=[ + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=d.dim_value) + for d in node.type.tensor_type.shape.dim + ] + ) + nodes.append( + NodeDef( + name=node.name.encode(encoding="utf_8"), + op="Variable", + input=[], + attr={ + "dtype": AttrValue(type=node.type.tensor_type.elem_type), + "shape": AttrValue(shape=shapeproto), + }, + ) + ) + + for node in graph.node: + _attr = [" = ".join([str(f[1]) for f in s.ListFields()]) for s in node.attribute] + attr = ", ".join(_attr).encode(encoding="utf_8") + print(node.output[0]) + nodes.append( + NodeDef( + name=node.output[0].encode(encoding="utf_8"), + op=node.op_type, + input=node.input, + attr={"parameters": AttrValue(s=attr)}, + ) + ) + + # two pass token replacement, appends opname to object id + mapping = {} + for node in nodes: + mapping[node.name] = node.op + "_" + node.name + + return GraphDef(node=nodes, versions=VersionDef(producer=22)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b79ba0dfac04802b057a1c29109a0ff163711d6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py @@ -0,0 +1,59 @@ +import torch + +from collections.abc import Sequence +from tensorboard.compat.proto.node_def_pb2 import NodeDef +from tensorboard.compat.proto.attr_value_pb2 import AttrValue +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + + +# pyrefly: ignore [not-a-type] +def attr_value_proto(dtype: object, shape: Sequence[int] | None, s: str | None) -> dict[str, AttrValue]: + """Create a dict of objects matching a NodeDef's attr field. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto + specifically designed for a NodeDef. The values have been reverse engineered from + standard TensorBoard logged data. + """ + attr = {} + if s is not None: + attr["attr"] = AttrValue(s=s.encode(encoding="utf_8")) + if shape is not None: + shapeproto = tensor_shape_proto(shape) + # pyrefly: ignore [missing-attribute] + attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto])) + return attr + + +# pyrefly: ignore [not-a-type] +def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: + """Create an object matching a tensor_shape field. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto . + """ + # pyrefly: ignore [missing-attribute] + return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize]) + + +def node_proto( + name: str, + op: str = "UnSpecified", + input: list[str] | str | None = None, + dtype: torch.dtype | None = None, + shape: tuple[int, ...] | None = None, + outputsize: Sequence[int] | None = None, + attributes: str = "", +) -> NodeDef: # pyrefly: ignore [not-a-type] + """Create an object matching a NodeDef. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto . + """ + if input is None: + input = [] + if not isinstance(input, list): + input = [input] + return NodeDef( + name=name.encode(encoding="utf_8"), + op=op, + input=input, + attr=attr_value_proto(dtype, outputsize, attributes), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..5a052016130b100f80b9a07dbb126182bce530cd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py @@ -0,0 +1,378 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict +import contextlib +from typing import Any + +from tensorboard.compat.proto.config_pb2 import RunMetadata +from tensorboard.compat.proto.graph_pb2 import GraphDef +from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats +from tensorboard.compat.proto.versions_pb2 import VersionDef + +import torch +from ._proto_graph import node_proto + +methods_OP = [ + "attributeNames", + "hasMultipleOutputs", + "hasUses", + "inputs", + "kind", + "outputs", + "outputsSize", + "scopeName", +] +# Some additional methods to explure for methods_IO are +# +# 'unique' (type int) +# 'type' (type >) +# +# But the below are sufficient for now. +methods_IO = ["node", "offset", "debugName"] + +GETATTR_KIND = "prim::GetAttr" +CLASSTYPE_KIND = "ClassType" + + +class NodeBase: + def __init__( + self, + debugName=None, + inputs=None, + scope=None, + tensor_size=None, + op_type="UnSpecified", + attributes="", + ) -> None: + # TODO; Specify a __slots__ for this class or potentially + # used namedtuple instead + self.debugName = debugName + self.inputs = inputs + self.tensor_size = tensor_size + self.kind = op_type + self.attributes = attributes + self.scope = scope + + def __repr__(self) -> str: + repr = [] + repr.append(str(type(self))) + repr.extend( + m + ": " + str(getattr(self, m)) + str(type(getattr(self, m))) + for m in dir(self) + if "__" not in m + ) + return "\n".join(repr) + "\n\n" + + +class NodePy(NodeBase): + def __init__(self, node_cpp, valid_methods) -> None: + super().__init__(node_cpp) + valid_methods = valid_methods[:] + self.inputs = [] + + for m in valid_methods: + if m == "inputs" or m == "outputs": + list_of_node = list(getattr(node_cpp, m)()) + io_unique_names = [] + io_tensor_sizes = [] + for n in list_of_node: + io_unique_names.append(n.debugName()) + if n.isCompleteTensor(): + io_tensor_sizes.append(n.type().sizes()) + else: + io_tensor_sizes.append(None) + + setattr(self, m, io_unique_names) + setattr(self, m + "tensor_size", io_tensor_sizes) + + else: + setattr(self, m, getattr(node_cpp, m)()) + + +class NodePyIO(NodePy): + def __init__(self, node_cpp, input_or_output=None) -> None: + super().__init__(node_cpp, methods_IO) + try: + tensor_size = node_cpp.type().sizes() + except RuntimeError: + tensor_size = [ + 1, + ] # fail when constant model is used. + self.tensor_size = tensor_size + # Kind attribute string is purely descriptive and will be shown + # in detailed information for the node in TensorBoard's graph plugin. + # + # NodePyOP nodes get this from their kind() method. + self.kind = "Parameter" + if input_or_output: + self.input_or_output = input_or_output + self.kind = "IO Node" + + +class NodePyOP(NodePy): + def __init__(self, node_cpp) -> None: + super().__init__(node_cpp, methods_OP) + # Replace single quote which causes strange behavior in TensorBoard + # TODO: See if we can remove this in the future + self.attributes = str( + {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()} + ).replace("'", " ") + self.kind = node_cpp.kind() + + +class GraphPy: + """Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard. + + GraphDef generation operates in two passes: + + In the first pass, all nodes are read and saved to two lists. + One list is for input/output nodes (nodes_io), which only have inbound + or outbound connections, but not both. Another list is for internal + operator nodes (nodes_op). The first pass also saves all scope name + appeared in the nodes in scope_name_appeared list for later processing. + + In the second pass, scope names are fully applied to all nodes. + debugNameToScopedName is a mapping from a node's ID to its fully qualified + scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have + totally correct scope output, so this is nontrivial. The function + populate_namespace_from_OP_to_IO and find_common_root are used to + assign scope name to a node based on the connection between nodes + in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name + and scope_name_appeared. + """ + + def __init__(self) -> None: + self.nodes_op = [] + self.nodes_io = OrderedDict() + self.unique_name_to_scoped_name = {} + self.shallowest_scope_name = "default" + self.scope_name_appeared = [] + + def append(self, x) -> None: + if isinstance(x, NodePyIO): + self.nodes_io[x.debugName] = x + if isinstance(x, NodePyOP): + self.nodes_op.append(x) + + def printall(self) -> None: + print("all nodes") + for node in self.nodes_op: + print(node) + for key in self.nodes_io: + print(self.nodes_io[key]) + + def find_common_root(self) -> None: + for fullscope in self.scope_name_appeared: + if fullscope: + self.shallowest_scope_name = fullscope.split("/")[0] + + def populate_namespace_from_OP_to_IO(self) -> None: + for node in self.nodes_op: + for node_output, outputSize in zip(node.outputs, node.outputstensor_size, strict=True): + self.scope_name_appeared.append(node.scopeName) + self.nodes_io[node_output] = NodeBase( + node_output, + node.inputs, + node.scopeName, + outputSize, + op_type=node.kind, + attributes=node.attributes, + ) + + self.find_common_root() + + for node in self.nodes_op: + for input_node_id in node.inputs: + self.unique_name_to_scoped_name[input_node_id] = ( + node.scopeName + "/" + input_node_id + ) + + for key, node in self.nodes_io.items(): + if type(node) is NodeBase: + # pyrefly: ignore [unsupported-operation] + self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName + if hasattr(node, "input_or_output"): + self.unique_name_to_scoped_name[key] = ( + node.input_or_output + "/" + node.debugName + ) + + if hasattr(node, "scope") and node.scope is not None: + self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName + if node.scope == "" and self.shallowest_scope_name: + self.unique_name_to_scoped_name[node.debugName] = ( + # pyrefly: ignore [unsupported-operation] + self.shallowest_scope_name + "/" + node.debugName + ) + + # replace name + for key, node in self.nodes_io.items(): + self.nodes_io[key].inputs = [ + self.unique_name_to_scoped_name[node_input_id] + for node_input_id in node.inputs + ] + if node.debugName in self.unique_name_to_scoped_name: + self.nodes_io[key].debugName = self.unique_name_to_scoped_name[ + node.debugName + ] + + def to_proto(self): + """Convert graph representation of GraphPy object to TensorBoard required format.""" + # TODO: compute correct memory usage and CPU time once + # PyTorch supports it + nodes = [ + node_proto( + v.debugName, + input=v.inputs, + outputsize=v.tensor_size, + op=v.kind, + attributes=v.attributes, + ) + for v in self.nodes_io.values() + ] + return nodes + + +def parse(graph, trace, args=None, omit_useless_nodes=True): + """Parse an optimized PyTorch model graph and produces a list of nodes and node stats. + + Useful for eventual conversion to TensorBoard protobuf format. + + Args: + graph (PyTorch module): The model graph to be parsed. + trace (PyTorch JIT TracedModule): The model trace to be parsed. + args (tuple): input tensor[s] for the model. + omit_useless_nodes (boolean): Whether to remove nodes from the graph. + """ + nodes_py = GraphPy() + for node in graph.inputs(): + if omit_useless_nodes: + if ( + len(node.uses()) == 0 + ): # number of user of the node (= number of outputs/ fanout) + continue + + if node.type().kind() != CLASSTYPE_KIND: + nodes_py.append(NodePyIO(node, "input")) + + attr_to_scope: dict[Any, str] = {} + for node in graph.nodes(): + if node.kind() == GETATTR_KIND: + attr_name = node.s("name") + attr_key = node.output().debugName() + parent = node.input().node() + if ( + parent.kind() == GETATTR_KIND + ): # If the parent node is not the top-level "self" node + parent_attr_key = parent.output().debugName() + parent_scope = attr_to_scope[parent_attr_key] + attr_scope = parent_scope.split("/")[-1] + attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}" + else: + attr_to_scope[attr_key] = f"__module.{attr_name}" + # We don't need classtype nodes; scope will provide this information + if node.output().type().kind() != CLASSTYPE_KIND: + node_py = NodePyOP(node) + node_py.scopeName = attr_to_scope[attr_key] # type: ignore[attr-defined] + nodes_py.append(node_py) + else: + nodes_py.append(NodePyOP(node)) + + for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops + node_pyio = NodePyIO(node, "output") + node_pyio.debugName = f"output.{i + 1}" + node_pyio.inputs = [node.debugName()] + nodes_py.append(node_pyio) + + def parse_traced_name(module): + if isinstance(module, torch.jit.TracedModule): + module_name = module._name + else: + module_name = getattr(module, "original_name", "Module") + return module_name + + alias_to_name = {} + base_name = parse_traced_name(trace) + for name, module in trace.named_modules(prefix="__module"): + mod_name = parse_traced_name(module) + attr_name = name.split(".")[-1] + alias_to_name[name] = f"{mod_name}[{attr_name}]" + + for node in nodes_py.nodes_op: + module_aliases = node.scopeName.split("/") + replacements = [ + alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1] + for alias in module_aliases + ] + node.scopeName = base_name + if any(replacements): + node.scopeName += "/" + "/".join(replacements) + + nodes_py.populate_namespace_from_OP_to_IO() + return nodes_py.to_proto() + + +def graph(model, args, verbose=False, use_strict_trace=True): + """ + Process a PyTorch model and produces a `GraphDef` proto that can be logged to TensorBoard. + + Args: + model (PyTorch module): The model to be parsed. + args (tuple): input tensor[s] for the model. + verbose (bool): Whether to print out verbose information while + processing. + use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) + """ + with _set_model_to_eval(model): + try: + trace = torch.jit.trace(model, args, strict=use_strict_trace) + graph = trace.graph + torch._C._jit_pass_inline(graph) + except RuntimeError as e: + print(e) + print("Error occurs, No graph saved") + raise e + + if verbose: + print(graph) + list_of_nodes = parse(graph, trace, args) + # We are hardcoding that this was run on CPU even though it might have actually + # run on GPU. Note this is what is shown in TensorBoard and has no bearing + # on actual execution. + # TODO: See if we can extract GPU vs CPU information from the PyTorch model + # and pass it correctly to TensorBoard. + # + # Definition of StepStats and DeviceStepStats can be found at + # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/proto.ts + # and + # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto + stepstats = RunMetadata( + step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]) + ) + return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats + # The producer version has been reverse engineered from standard + # TensorBoard logged data. + + +@contextlib.contextmanager +def _set_model_to_eval(model): + """Context manager to temporarily set the training mode of ``model`` to eval.""" + if not isinstance(model, torch.jit.ScriptFunction): + originally_training = model.training + model.train(False) + try: + yield + finally: + model.train(originally_training) + else: + # Do nothing for ScriptFunction + try: + yield + finally: + pass + + +def _node_get(node: torch._C.Node, key: str): + """Get attributes of a node which is polymorphic over return type.""" + sel = node.kindOf(key) + return getattr(node, sel)(key) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1bafc22183afbce001c8db20bd62f35cbcc2a663 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +import numpy as np +import numpy.typing as npt + + +# Functions for converting +def figure_to_image(figures, close=True): + """Render matplotlib figure to numpy format. + + Note that this requires the ``matplotlib`` package. + + Args: + figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures + close (bool): Flag to automatically close the figure + + Returns: + numpy.array: image in [CHW] order + """ + import matplotlib.pyplot as plt + import matplotlib.backends.backend_agg as plt_backend_agg + + def render_to_rgb(figure): + canvas = plt_backend_agg.FigureCanvasAgg(figure) + canvas.draw() + data: npt.NDArray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) + w, h = figure.canvas.get_width_height() + image_hwc = data.reshape([h, w, 4])[:, :, 0:3] + image_chw = np.moveaxis(image_hwc, source=2, destination=0) + if close: + plt.close(figure) + return image_chw + + if isinstance(figures, list): + images = [render_to_rgb(figure) for figure in figures] + return np.stack(images) + else: + image = render_to_rgb(figures) + return image + + +def _prepare_video(V): + """ + Convert a 5D tensor into 4D tensor. + + Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor) + to [time(frame), new_width, new_height, channel] (4D tensor). + + A batch of images are spread to a grid, which forms a frame. + e.g. Video with batchsize 16 will have a 4x4 grid. + """ + b, t, c, h, w = V.shape + + if V.dtype == np.uint8: + V = np.float32(V) / 255.0 + + def is_power2(num): + return num != 0 and ((num & (num - 1)) == 0) + + # pad to nearest power of 2, all at once + # pyrefly: ignore [index-error] + if not is_power2(V.shape[0]): + # pyrefly: ignore [index-error] + len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0]) + V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0) + + n_rows = 2 ** ((b.bit_length() - 1) // 2) + # pyrefly: ignore [index-error] + n_cols = V.shape[0] // n_rows + + V = np.reshape(V, (n_rows, n_cols, t, c, h, w)) + V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3)) + V = np.reshape(V, (t, n_rows * h, n_cols * w, c)) + + return V + + +def make_grid(I, ncols=8): + # I: N1HW or N3HW + if not isinstance(I, np.ndarray): + raise AssertionError("plugin error, should pass numpy array here") + if I.shape[1] == 1: + I = np.concatenate([I, I, I], 1) + if I.ndim != 4 or I.shape[1] != 3: + raise AssertionError("Input should be a 4D numpy array with 3 channels") + nimg = I.shape[0] + H = I.shape[2] + W = I.shape[3] + ncols = min(nimg, ncols) + nrows = int(np.ceil(float(nimg) / ncols)) + canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype) + i = 0 + for y in range(nrows): + for x in range(ncols): + if i >= nimg: + break + canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i] + i = i + 1 + return canvas + + # if modality == 'IMG': + # if x.dtype == np.uint8: + # x = x.astype(np.float32) / 255.0 + + +def convert_to_HWC(tensor, input_format): # tensor: numpy array + if len(set(input_format)) != len(input_format): + raise AssertionError(f"You can not use the same dimension shordhand twice. \ + input_format: {input_format}") + if len(tensor.shape) != len(input_format): + raise AssertionError(f"size of input tensor and input format are different. \ + tensor shape: {tensor.shape}, input_format: {input_format}") + input_format = input_format.upper() + + if len(input_format) == 4: + index = [input_format.find(c) for c in "NCHW"] + tensor_NCHW = tensor.transpose(index) + tensor_CHW = make_grid(tensor_NCHW) + return tensor_CHW.transpose(1, 2, 0) + + if len(input_format) == 3: + index = [input_format.find(c) for c in "HWC"] + tensor_HWC = tensor.transpose(index) + if tensor_HWC.shape[2] == 1: + tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2) + return tensor_HWC + + if len(input_format) == 2: + index = [input_format.find(c) for c in "HW"] + tensor = tensor.transpose(index) + tensor = np.stack([tensor, tensor, tensor], 2) + return tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..3e538ddc4c02d9f26d80e653eb57e66ae7146ed4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py @@ -0,0 +1,1032 @@ +# mypy: allow-untyped-defs +import json +import logging +import struct + +from typing import Any + +import torch +import numpy as np + + +from google.protobuf import struct_pb2 + +from tensorboard.compat.proto.summary_pb2 import ( + HistogramProto, + Summary, + SummaryMetadata, +) +from tensorboard.compat.proto.tensor_pb2 import TensorProto +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto +from tensorboard.plugins.custom_scalar import layout_pb2 +from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData +from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData + +from ._convert_np import make_np +from ._utils import _prepare_video, convert_to_HWC + +__all__ = [ + "half_to_int", + "int_to_half", + "hparams", + "scalar", + "histogram_raw", + "histogram", + "make_histogram", + "image", + "image_boxes", + "draw_boxes", + "make_image", + "video", + "make_video", + "audio", + "custom_scalars", + "text", + "tensor_proto", + "pr_curve_raw", + "pr_curve", + "compute_curve", + "mesh", +] + +logger = logging.getLogger(__name__) + +def half_to_int(f: float) -> int: + """Casts a half-precision float value into an integer. + + Converts a half precision floating point value, such as `torch.half` or + `torch.bfloat16`, into an integer value which can be written into the + half_val field of a TensorProto for storage. + + To undo the effects of this conversion, use int_to_half(). + + """ + buf = struct.pack("f", f) + return struct.unpack("i", buf)[0] + +def int_to_half(i: int) -> float: + """Casts an integer value to a half-precision float. + + Converts an integer value obtained from half_to_int back into a floating + point value. + + """ + buf = struct.pack("i", i) + return struct.unpack("f", buf)[0] + +def _tensor_to_half_val(t: torch.Tensor) -> list[int]: + return [half_to_int(x) for x in t.flatten().tolist()] + +def _tensor_to_complex_val(t: torch.Tensor) -> list[float]: + return torch.view_as_real(t).flatten().tolist() + +def _tensor_to_list(t: torch.Tensor) -> list[Any]: + return t.flatten().tolist() + +# type maps: torch.Tensor type -> (protobuf type, protobuf val field) +_TENSOR_TYPE_MAP = { + torch.half: ("DT_HALF", "half_val", _tensor_to_half_val), + torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val), + torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val), + torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list), + torch.float: ("DT_FLOAT", "float_val", _tensor_to_list), + torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list), + torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list), + torch.int8: ("DT_INT8", "int_val", _tensor_to_list), + torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list), + torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list), + torch.int16: ("DT_INT16", "int_val", _tensor_to_list), + torch.short: ("DT_INT16", "int_val", _tensor_to_list), + torch.int: ("DT_INT32", "int_val", _tensor_to_list), + torch.int32: ("DT_INT32", "int_val", _tensor_to_list), + torch.qint32: ("DT_INT32", "int_val", _tensor_to_list), + torch.int64: ("DT_INT64", "int64_val", _tensor_to_list), + torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val), + torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val), + torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val), + torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val), + torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list), + torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val), + torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val), + torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list), + torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list), + torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list), +} + + +def _calc_scale_factor(tensor) -> int: + converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor + return 1 if converted.dtype == np.uint8 else 255 + + +def _draw_single_box( + image, + xmin, + ymin, + xmax, + ymax, + display_str, + color="black", + color_text="black", + thickness=2, +): + from PIL import ImageDraw, ImageFont + + font = ImageFont.load_default() + draw = ImageDraw.Draw(image) + (left, right, top, bottom) = (xmin, xmax, ymin, ymax) + draw.line( + [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], + width=thickness, + fill=color, + ) + if display_str: + text_bottom = bottom + # Reverse list and print from bottom to top. + _left, _top, _right, _bottom = font.getbbox(display_str) + text_width, text_height = _right - _left, _bottom - _top + margin = np.ceil(0.05 * text_height) + draw.rectangle( + [ + (left, text_bottom - text_height - 2 * margin), + (left + text_width, text_bottom), + ], + fill=color, + ) + draw.text( + (left + margin, text_bottom - text_height - margin), + display_str, + fill=color_text, + font=font, + ) + return image + + +def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): + """Output three `Summary` protocol buffers needed by hparams plugin. + + `Experiment` keeps the metadata of an experiment, such as the name of the + hyperparameters and the name of the metrics. + `SessionStartInfo` keeps key-value pairs of the hyperparameters + `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS + + Args: + hparam_dict: A dictionary that contains names of the hyperparameters + and their values. + metric_dict: A dictionary that contains names of the metrics + and their values. + hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that + contains names of the hyperparameters and all discrete values they can hold + + Returns: + The `Summary` protobufs for Experiment, SessionStartInfo and + SessionEndInfo + """ + import torch + from tensorboard.plugins.hparams.api_pb2 import ( + DataType, + Experiment, + HParamInfo, + MetricInfo, + MetricName, + Status, + ) + from tensorboard.plugins.hparams.metadata import ( + EXPERIMENT_TAG, + PLUGIN_DATA_VERSION, + PLUGIN_NAME, + SESSION_END_INFO_TAG, + SESSION_START_INFO_TAG, + ) + from tensorboard.plugins.hparams.plugin_data_pb2 import ( + HParamsPluginData, + SessionEndInfo, + SessionStartInfo, + ) + + # TODO: expose other parameters in the future. + # hp = HParamInfo(name='lr',display_name='learning rate', + # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10, + # max_value=100)) + # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy', + # description='', dataset_type=DatasetType.DATASET_VALIDATION) + # exp = Experiment(name='123', description='456', time_created_secs=100.0, + # hparam_infos=[hp], metric_infos=[mt], user='tw') + + if not isinstance(hparam_dict, dict): + logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.") + raise TypeError( + "parameter: hparam_dict should be a dictionary, nothing logged." + ) + if not isinstance(metric_dict, dict): + logger.warning("parameter: metric_dict should be a dictionary, nothing logged.") + raise TypeError( + "parameter: metric_dict should be a dictionary, nothing logged." + ) + + hparam_domain_discrete = hparam_domain_discrete or {} + if not isinstance(hparam_domain_discrete, dict): + raise TypeError( + "parameter: hparam_domain_discrete should be a dictionary, nothing logged." + ) + for k, v in hparam_domain_discrete.items(): + if ( + k not in hparam_dict + or not isinstance(v, list) + or not all(isinstance(d, type(hparam_dict[k])) for d in v) + ): + raise TypeError( + f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]." + ) + hps = [] + + ssi = SessionStartInfo() + for k, v in hparam_dict.items(): + if v is None: + continue + if isinstance(v, (int, float)): + ssi.hparams[k].number_value = v + + if k in hparam_domain_discrete: + domain_discrete: struct_pb2.ListValue | None = struct_pb2.ListValue( + values=[ + struct_pb2.Value(number_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + # pyrefly: ignore [missing-attribute] + type=DataType.Value("DATA_TYPE_FLOAT64"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, str): + ssi.hparams[k].string_value = v + + if k in hparam_domain_discrete: + domain_discrete = struct_pb2.ListValue( + values=[ + struct_pb2.Value(string_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + # pyrefly: ignore [missing-attribute] + type=DataType.Value("DATA_TYPE_STRING"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, bool): + ssi.hparams[k].bool_value = v + + if k in hparam_domain_discrete: + domain_discrete = struct_pb2.ListValue( + values=[ + struct_pb2.Value(bool_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + # pyrefly: ignore [missing-attribute] + type=DataType.Value("DATA_TYPE_BOOL"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, torch.Tensor): + v = make_np(v)[0] + ssi.hparams[k].number_value = v + # pyrefly: ignore [missing-attribute] + hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64"))) + continue + raise ValueError( + "value should be one of int, float, str, bool, or torch.Tensor" + ) + + content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + # pyrefly: ignore [missing-attribute] + ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) + + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict] + + exp = Experiment(hparam_infos=hps, metric_infos=mts) + + content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + # pyrefly: ignore [missing-attribute] + exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)]) + + # pyrefly: ignore [missing-attribute] + sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS")) + content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + # pyrefly: ignore [missing-attribute] + sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)]) + + return exp, ssi, sei + + +def scalar(name, tensor, collections=None, new_style=False, double_precision=False): + """Output a `Summary` protocol buffer containing a single scalar value. + + The generated Summary has a Tensor.proto containing the input Tensor. + Args: + name: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: A real numeric Tensor containing a single value. + collections: Optional list of graph collections keys. The new summary op is + added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. + new_style: Whether to use new style (tensor field) or old style (simple_value + field). New style could lead to faster data loading. + Returns: + A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf. + Raises: + ValueError: If tensor has the wrong shape or type. + """ + tensor = make_np(tensor).squeeze() + if tensor.ndim != 0: + raise AssertionError(f"Tensor should contain one element (0 dimensions). \ + Was given size: {tensor.size} and {tensor.ndim} dimensions.") + # python float is double precision in numpy + scalar = float(tensor) + if new_style: + tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT") + if double_precision: + tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE") + + # pyrefly: ignore [missing-attribute] + plugin_data = SummaryMetadata.PluginData(plugin_name="scalars") + smd = SummaryMetadata(plugin_data=plugin_data) + return Summary( + value=[ + # pyrefly: ignore [missing-attribute] + Summary.Value( + tag=name, + tensor=tensor_proto, + metadata=smd, + ) + ] + ) + else: + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=name, simple_value=scalar)]) + + +def tensor_proto(tag, tensor): + """Outputs a `Summary` protocol buffer containing the full tensor. + The generated Summary has a Tensor.proto containing the input Tensor. + Args: + tag: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: Tensor to be converted to protobuf + Returns: + A tensor protobuf in a `Summary` protobuf. + Raises: + ValueError: If tensor is too big to be converted to protobuf, or + tensor data type is not supported + """ + if tensor.numel() * tensor.itemsize >= (1 << 31): + raise ValueError( + "tensor is bigger than protocol buffer's hard limit of 2GB in size" + ) + + if tensor.dtype in _TENSOR_TYPE_MAP: + dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype] + tensor_proto = TensorProto( + **{ + "dtype": dtype, + "tensor_shape": TensorShapeProto( + # pyrefly: ignore [missing-attribute] + dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape] + ), + field_name: conversion_fn(tensor), + }, + ) + else: + raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}") + + # pyrefly: ignore [missing-attribute] + plugin_data = SummaryMetadata.PluginData(plugin_name="tensor") + smd = SummaryMetadata(plugin_data=plugin_data) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)]) + + +def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts): + # pylint: disable=line-too-long + """Output a `Summary` protocol buffer with a histogram. + + The generated + [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + has one summary value containing a histogram for `values`. + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + min: A float or int min value + max: A float or int max value + num: Int number of values + sum: Float or int sum of all values + sum_squares: Float or int sum of squares for all values + bucket_limits: A numeric `Tensor` with upper value per bucket + bucket_counts: A numeric `Tensor` with number of values per bucket + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. + """ + hist = HistogramProto( + min=min, + max=max, + num=num, + sum=sum, + sum_squares=sum_squares, + bucket_limit=bucket_limits, + bucket=bucket_counts, + ) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=name, histo=hist)]) + + +def histogram(name, values, bins, max_bins=None): + # pylint: disable=line-too-long + """Output a `Summary` protocol buffer with a histogram. + + The generated + [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + has one summary value containing a histogram for `values`. + This op reports an `InvalidArgument` error if any value is not finite. + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + values: A real numeric `Tensor`. Any shape. Values to use to + build the histogram. + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. + """ + values = make_np(values) + hist = make_histogram(values.astype(float), bins, max_bins) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=name, histo=hist)]) + + +def make_histogram(values, bins, max_bins=None): + """Convert values into a histogram proto using logic from histogram.cc.""" + if values.size == 0: + raise ValueError("The input has no element.") + values = values.reshape(-1) + counts, limits = np.histogram(values, bins=bins) + num_bins = len(counts) + if max_bins is not None and num_bins > max_bins: + subsampling = num_bins // max_bins + subsampling_remainder = num_bins % subsampling + if subsampling_remainder != 0: + # pyrefly: ignore [no-matching-overload] + counts = np.pad( + counts, + pad_width=[[0, subsampling - subsampling_remainder]], + mode="constant", + constant_values=0, + ) + counts = counts.reshape(-1, subsampling).sum(axis=-1) + new_limits = np.empty((counts.size + 1,), limits.dtype) + new_limits[:-1] = limits[:-1:subsampling] + new_limits[-1] = limits[-1] + limits = new_limits + + # Find the first and the last bin defining the support of the histogram: + + cum_counts = np.cumsum(np.greater(counts, 0)) + start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right") + start = int(start) + end = int(end) + 1 + del cum_counts + + # TensorBoard only includes the right bin limits. To still have the leftmost limit + # included, we include an empty bin left. + # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the + # first nonzero-count bin: + counts = ( + counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]]) + ) + limits = limits[start : end + 1] + + if counts.size == 0 or limits.size == 0: + raise ValueError("The histogram is empty, please file a bug report.") + + sum_sq = values.dot(values) + return HistogramProto( + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limit=limits.tolist(), + bucket=counts.tolist(), + ) + + +def image(tag, tensor, rescale=1, dataformats="NCHW"): + """Output a `Summary` protocol buffer with images. + + The summary has up to `max_images` summary values containing images. The + images are built from `tensor` which must be 3-D with shape `[height, width, + channels]` and where `channels` can be: + * 1: `tensor` is interpreted as Grayscale. + * 3: `tensor` is interpreted as RGB. + * 4: `tensor` is interpreted as RGBA. + The `name` in the outputted Summary.Value protobufs is generated based on the + name, with a suffix depending on the max_outputs setting: + * If `max_outputs` is 1, the summary value tag is '*name*/image'. + * If `max_outputs` is greater than 1, the summary value tags are + generated sequentially as '*name*/image/0', '*name*/image/1', etc. + Args: + tag: A name for the generated node. Will also serve as a series name in + TensorBoard. + tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width, + channels]` where `channels` is 1, 3, or 4. + 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8). + The image() function will scale the image values to [0, 255] by applying + a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values + will be clipped. + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. + """ + tensor = make_np(tensor) + tensor = convert_to_HWC(tensor, dataformats) + # Do not assume that user passes in values in [0, 255], use data type to detect + scale_factor = _calc_scale_factor(tensor) + tensor = tensor.astype(np.float32) + tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) + image = make_image(tensor, rescale=rescale) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, image=image)]) + + +def image_boxes( + tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None +): + """Output a `Summary` protocol buffer with images.""" + tensor_image = make_np(tensor_image) + tensor_image = convert_to_HWC(tensor_image, dataformats) + tensor_boxes = make_np(tensor_boxes) + tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image) + image = make_image( + tensor_image.clip(0, 255).astype(np.uint8), + rescale=rescale, + rois=tensor_boxes, + labels=labels, + ) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, image=image)]) + + +def draw_boxes(disp_image, boxes, labels=None): + # xyxy format + num_boxes = boxes.shape[0] + list_gt = range(num_boxes) + for i in list_gt: + disp_image = _draw_single_box( + disp_image, + boxes[i, 0], + boxes[i, 1], + boxes[i, 2], + boxes[i, 3], + display_str=None if labels is None else labels[i], + color="Red", + ) + return disp_image + + +def make_image(tensor, rescale=1, rois=None, labels=None): + """Convert a numpy representation of an image to Image protobuf.""" + from PIL import Image + + height, width, channel = tensor.shape + scaled_height = int(height * rescale) + scaled_width = int(width * rescale) + image = Image.fromarray(tensor) + if rois is not None: + image = draw_boxes(image, rois, labels=labels) + ANTIALIAS = Image.Resampling.LANCZOS + image = image.resize((scaled_width, scaled_height), ANTIALIAS) + import io + + output = io.BytesIO() + image.save(output, format="PNG") + image_string = output.getvalue() + output.close() + # pyrefly: ignore [missing-attribute] + return Summary.Image( + height=height, + width=width, + colorspace=channel, + encoded_image_string=image_string, + ) + + +def video(tag, tensor, fps=4): + tensor = make_np(tensor) + tensor = _prepare_video(tensor) + # If user passes in uint8, then we don't need to rescale by 255 + scale_factor = _calc_scale_factor(tensor) + tensor = tensor.astype(np.float32) + tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) + video = make_video(tensor, fps) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, image=video)]) + + +def make_video(tensor, fps): + try: + import moviepy # noqa: F401 + except ImportError: + print("add_video needs package moviepy") + return + try: + from moviepy import editor as mpy + except ImportError: + print( + "moviepy is installed, but can't import moviepy.editor.", + "Some packages could be missing [imageio, requests]", + ) + return + import tempfile + + _t, h, w, c = tensor.shape + + # encode sequence of images into gif string + clip = mpy.ImageSequenceClip(list(tensor), fps=fps) + + with tempfile.NamedTemporaryFile(suffix=".gif") as f: + filename = f.name + try: # newer version of moviepy use logger instead of progress_bar argument. + clip.write_gif(filename, verbose=False, logger=None) + except TypeError: + try: # older version of moviepy does not support progress_bar argument. + clip.write_gif(filename, verbose=False, progress_bar=False) + except TypeError: + clip.write_gif(filename, verbose=False) + + f.seek(0) + tensor_string = f.read() + + # pyrefly: ignore [missing-attribute] + return Summary.Image( + height=h, width=w, colorspace=c, encoded_image_string=tensor_string + ) + + +def audio(tag, tensor, sample_rate=44100): + array = make_np(tensor) + array = array.squeeze() + if abs(array).max() > 1: + print("warning: audio amplitude out of range, auto clipped.") + array = array.clip(-1, 1) + if array.ndim != 1: + raise AssertionError("input tensor should be 1 dimensional.") + array = (array * np.iinfo(np.int16).max).astype(" 127: # weird, value > 127 breaks protobuf + num_thresholds = 127 + data = np.stack((tp, fp, tn, fn, precision, recall)) + pr_curve_plugin_data = PrCurvePluginData( + version=0, num_thresholds=num_thresholds + ).SerializeToString() + # pyrefly: ignore [missing-attribute] + plugin_data = SummaryMetadata.PluginData( + plugin_name="pr_curves", content=pr_curve_plugin_data + ) + smd = SummaryMetadata(plugin_data=plugin_data) + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=data.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=data.shape[0]), + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=data.shape[1]), + ] + ), + ) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + + +def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None): + # weird, value > 127 breaks protobuf + num_thresholds = min(num_thresholds, 127) + data = compute_curve( + labels, predictions, num_thresholds=num_thresholds, weights=weights + ) + pr_curve_plugin_data = PrCurvePluginData( + version=0, num_thresholds=num_thresholds + ).SerializeToString() + # pyrefly: ignore [missing-attribute] + plugin_data = SummaryMetadata.PluginData( + plugin_name="pr_curves", content=pr_curve_plugin_data + ) + smd = SummaryMetadata(plugin_data=plugin_data) + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=data.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=data.shape[0]), + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=data.shape[1]), + ] + ), + ) + # pyrefly: ignore [missing-attribute] + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + + +# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py +def compute_curve(labels, predictions, num_thresholds=None, weights=None): + _MINIMUM_COUNT = 1e-7 + + if weights is None: + weights = 1.0 + + # Compute bins of true positives and false positives. + # pyrefly: ignore [unsupported-operation] + bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) + float_labels = labels.astype(np.float64) + # pyrefly: ignore [unsupported-operation] + histogram_range = (0, num_thresholds - 1) + tp_buckets, _ = np.histogram( + bucket_indices, + # pyrefly: ignore [bad-argument-type] + bins=num_thresholds, + range=histogram_range, + weights=float_labels * weights, + ) + fp_buckets, _ = np.histogram( + bucket_indices, + # pyrefly: ignore [bad-argument-type] + bins=num_thresholds, + range=histogram_range, + weights=(1.0 - float_labels) * weights, + ) + + # Obtain the reverse cumulative sum. + tp = np.cumsum(tp_buckets[::-1])[::-1] + fp = np.cumsum(fp_buckets[::-1])[::-1] + tn = fp[0] - fp + fn = tp[0] - tp + precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp) + recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn) + return np.stack((tp, fp, tn, fn, precision, recall)) + + +def _get_tensor_summary( + name, display_name, description, tensor, content_type, components, json_config +): + """Create a tensor summary with summary metadata. + + Args: + name: Uniquely identifiable name of the summary op. Could be replaced by + combination of name and type to make it unique even outside of this + summary. + display_name: Will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + tensor: Tensor to display in summary. + content_type: Type of content inside the Tensor. + components: Bitmask representing present parts (vertices, colors, etc.) that + belong to the summary. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + + Returns: + Tensor summary with metadata. + """ + import torch + from tensorboard.plugins.mesh import metadata + + tensor = torch.as_tensor(tensor) + + tensor_metadata = metadata.create_summary_metadata( + name, + display_name, + content_type, + components, + tensor.shape, + description, + json_config=json_config, + ) + + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=tensor.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=tensor.shape[0]), + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=tensor.shape[1]), + # pyrefly: ignore [missing-attribute] + TensorShapeProto.Dim(size=tensor.shape[2]), + ] + ), + ) + + # pyrefly: ignore [missing-attribute] + tensor_summary = Summary.Value( + tag=metadata.get_instance_name(name, content_type), + tensor=tensor, + metadata=tensor_metadata, + ) + + return tensor_summary + + +def _get_json_config(config_dict): + """Parse and returns JSON string from python dictionary.""" + json_config = "{}" + if config_dict is not None: + json_config = json.dumps(config_dict, sort_keys=True) + return json_config + + +# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py +def mesh( + tag, vertices, colors, faces, config_dict, display_name=None, description=None +): + """Output a merged `Summary` protocol buffer with a mesh/point cloud. + + Args: + tag: A name for this summary operation. + vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each + vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + config_dict: Dictionary with ThreeJS classes names and configuration. + + Returns: + Merged summary for mesh/point cloud representation. + """ + from tensorboard.plugins.mesh import metadata + from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData + + json_config = _get_json_config(config_dict) + + summaries = [] + tensors = [ + # pyrefly: ignore [missing-attribute] + (vertices, MeshPluginData.VERTEX), + # pyrefly: ignore [missing-attribute] + (faces, MeshPluginData.FACE), + # pyrefly: ignore [missing-attribute] + (colors, MeshPluginData.COLOR), + ] + tensors = [tensor for tensor in tensors if tensor[0] is not None] + components = metadata.get_components_bitmask( + [content_type for (tensor, content_type) in tensors] + ) + + for tensor, content_type in tensors: + summaries.append( + _get_tensor_summary( + tag, + display_name, + description, + tensor, + content_type, + components, + json_config, + ) + ) + + return Summary(value=summaries) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..008ae59e94e6a5172907b39df7062752ca74c954 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py @@ -0,0 +1,1219 @@ +# mypy: allow-untyped-defs +"""Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization.""" + +import os +import time +from typing import TYPE_CHECKING, Union + +import torch + +if TYPE_CHECKING: + from matplotlib.figure import Figure +from tensorboard.compat import tf +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto.event_pb2 import Event, SessionLog +from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig +from tensorboard.summary.writer.event_file_writer import EventFileWriter + +from ._convert_np import make_np +from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt +from ._onnx_graph import load_onnx_graph +from ._pytorch_graph import graph +from ._utils import figure_to_image +from .summary import ( + audio, + custom_scalars, + histogram, + histogram_raw, + hparams, + image, + image_boxes, + mesh, + pr_curve, + pr_curve_raw, + scalar, + tensor_proto, + text, + video, +) + +__all__ = ["FileWriter", "SummaryWriter"] + + +class FileWriter: + """Writes protocol buffers to event files to be consumed by TensorBoard. + + The `FileWriter` class provides a mechanism to create an event file in a + given directory and add summaries and events to it. The class updates the + file contents asynchronously. This allows a training program to call methods + to add data to the file directly from the training loop, without slowing down + training. + """ + + def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix="") -> None: + """Create a `FileWriter` and an event file. + + On construction the writer creates a new event file in `log_dir`. + The other arguments to the constructor control the asynchronous writes to + the event file. + + Args: + log_dir: A string. Directory where event file will be written. + max_queue: Integer. Size of the queue for pending events and + summaries before one of the 'add' calls forces a flush to disk. + Default is ten items. + flush_secs: Number. How often, in seconds, to flush the + pending events and summaries to disk. Default is every two minutes. + filename_suffix: A string. Suffix added to all event filenames + in the log_dir directory. More details on filename construction in + tensorboard.summary.writer.event_file_writer.EventFileWriter. + """ + # Sometimes PosixPath is passed in and we need to coerce it to + # a string in all cases + # TODO: See if we can remove this in the future if we are + # actually the ones passing in a PosixPath + log_dir = str(log_dir) + self.event_writer = EventFileWriter( + log_dir, max_queue, flush_secs, filename_suffix + ) + + def get_logdir(self): + """Return the directory where event file will be written.""" + return self.event_writer.get_logdir() + + def add_event(self, event, step=None, walltime=None) -> None: + """Add an event to the event file. + + Args: + event: An `Event` protocol buffer. + step: Number. Optional global step value for training process + to record with the event. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + event.wall_time = time.time() if walltime is None else walltime + if step is not None: + # Make sure step is converted from numpy or other formats + # since protobuf might not convert depending on version + event.step = int(step) + self.event_writer.add_event(event) + + def add_summary(self, summary, global_step=None, walltime=None) -> None: + """Add a `Summary` protocol buffer to the event file. + + This method wraps the provided summary in an `Event` protocol buffer + and adds it to the event file. + + Args: + summary: A `Summary` protocol buffer. + global_step: Number. Optional global step value for training process + to record with the summary. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + event = event_pb2.Event(summary=summary) + self.add_event(event, global_step, walltime) + + def add_graph(self, graph_profile, walltime=None) -> None: + """Add a `Graph` and step stats protocol buffer to the event file. + + Args: + graph_profile: A `Graph` and step stats protocol buffer. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + graph = graph_profile[0] + stepstats = graph_profile[1] + event = event_pb2.Event(graph_def=graph.SerializeToString()) + self.add_event(event, None, walltime) + + trm = event_pb2.TaggedRunMetadata( + tag="step1", run_metadata=stepstats.SerializeToString() + ) + event = event_pb2.Event(tagged_run_metadata=trm) + self.add_event(event, None, walltime) + + def add_onnx_graph(self, graph, walltime=None) -> None: + """Add a `Graph` protocol buffer to the event file. + + Args: + graph: A `Graph` protocol buffer. + walltime: float. Optional walltime to override the default (current) + _get_file_writerfrom time.time()) + """ + event = event_pb2.Event(graph_def=graph.SerializeToString()) + self.add_event(event, None, walltime) + + def flush(self) -> None: + """Flushes the event file to disk. + + Call this method to make sure that all pending events have been written to + disk. + """ + self.event_writer.flush() + + def close(self) -> None: + """Flushes the event file to disk and close the file. + + Call this method when you do not need the summary writer anymore. + """ + self.event_writer.close() + + def reopen(self) -> None: + """Reopens the EventFileWriter. + + Can be called after `close()` to add more events in the same directory. + The events will go into a new events file. + Does nothing if the EventFileWriter was not closed. + """ + # pyrefly: ignore [missing-attribute] + self.event_writer.reopen() + + +class SummaryWriter: + """Writes entries directly to event files in the log_dir to be consumed by TensorBoard. + + The `SummaryWriter` class provides a high-level API to create an event file + in a given directory and add summaries and events to it. The class updates the + file contents asynchronously. This allows a training program to call methods + to add data to the file directly from the training loop, without slowing down + training. + """ + + def __init__( + self, + log_dir=None, + comment="", + purge_step=None, + max_queue=10, + flush_secs=120, + filename_suffix="", + ) -> None: + """Create a `SummaryWriter` that will write out events and summaries to the event file. + + Args: + log_dir (str): Save directory location. Default is + runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run. + Use hierarchical folder structure to compare + between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc. + for each new experiment to compare across them. + comment (str): Comment log_dir suffix appended to the default + ``log_dir``. If ``log_dir`` is assigned, this argument has no effect. + purge_step (int): + When logging crashes at step :math:`T+X` and restarts at step :math:`T`, + any events whose global_step larger or equal to :math:`T` will be + purged and hidden from TensorBoard. + Note that crashed and resumed experiments should have the same ``log_dir``. + max_queue (int): Size of the queue for pending events and + summaries before one of the 'add' calls forces a flush to disk. + Default is ten items. + flush_secs (int): How often, in seconds, to flush the + pending events and summaries to disk. Default is every two minutes. + filename_suffix (str): Suffix added to all event filenames in + the log_dir directory. More details on filename construction in + tensorboard.summary.writer.event_file_writer.EventFileWriter. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + + # create a summary writer with automatically generated folder name. + writer = SummaryWriter() + # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/ + + # create a summary writer using the specified folder name. + writer = SummaryWriter("my_experiment") + # folder location: my_experiment + + # create a summary writer with comment appended. + writer = SummaryWriter(comment="LR_0.1_BATCH_16") + # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/ + + """ + torch._C._log_api_usage_once("tensorboard.create.summarywriter") + if not log_dir: + import socket + from datetime import datetime + + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + log_dir = os.path.join( + "runs", current_time + "_" + socket.gethostname() + comment + ) + self.log_dir = log_dir + self.purge_step = purge_step + self.max_queue = max_queue + self.flush_secs = flush_secs + self.filename_suffix = filename_suffix + + # Initialize the file writers, but they can be cleared out on close + # and recreated later as needed. + self.file_writer = self.all_writers = None + self._get_file_writer() + + # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard + v = 1e-12 + buckets = [] + neg_buckets = [] + while v < 1e20: + # pyrefly: ignore [bad-argument-type] + buckets.append(v) + # pyrefly: ignore [bad-argument-type] + neg_buckets.append(-v) + v *= 1.1 + self.default_bins = neg_buckets[::-1] + [0] + buckets + + def _get_file_writer(self): + """Return the default FileWriter instance. Recreates it if closed.""" + if self.all_writers is None or self.file_writer is None: + # pyrefly: ignore [bad-assignment] + self.file_writer = FileWriter( + self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix + ) + # pyrefly: ignore [bad-assignment, missing-attribute] + self.all_writers = {self.file_writer.get_logdir(): self.file_writer} + if self.purge_step is not None: + most_recent_step = self.purge_step + # pyrefly: ignore [missing-attribute] + self.file_writer.add_event( + Event(step=most_recent_step, file_version="brain.Event:2") + ) + # pyrefly: ignore [missing-attribute] + self.file_writer.add_event( + Event( + step=most_recent_step, + # pyrefly: ignore [missing-attribute] + session_log=SessionLog(status=SessionLog.START), + ) + ) + self.purge_step = None + return self.file_writer + + def get_logdir(self): + """Return the directory where event files will be written.""" + return self.log_dir + + def add_hparams( + self, + hparam_dict, + metric_dict, + hparam_domain_discrete=None, + run_name=None, + global_step=None, + ) -> None: + """Add a set of hyperparameters to be compared in TensorBoard. + + Args: + hparam_dict (dict): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. + The type of the value can be one of `bool`, `string`, `float`, + `int`, or `None`. + metric_dict (dict): Each key-value pair in the dictionary is the + name of the metric and it's corresponding value. Note that the key used + here should be unique in the tensorboard record. Otherwise the value + you added by ``add_scalar`` will be displayed in hparam plugin. In most + cases, this is unwanted. + hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that + contains names of the hyperparameters and all discrete values they can hold + run_name (str): Name of the run, to be included as part of the logdir. + If unspecified, will use current timestamp. + global_step (int): Global step value to record + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + with SummaryWriter() as w: + for i in range(5): + w.add_hparams({'lr': 0.1*i, 'bsize': i}, + {'hparam/accuracy': 10*i, 'hparam/loss': 10*i}) + + Expected result: + + .. image:: _static/img/tensorboard/add_hparam.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_hparams") + if type(hparam_dict) is not dict or type(metric_dict) is not dict: + raise TypeError("hparam_dict and metric_dict should be dictionary.") + exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete) + + if not run_name: + run_name = str(time.time()) + logdir = os.path.join(self._get_file_writer().get_logdir(), run_name) + with SummaryWriter(log_dir=logdir) as w_hp: + w_hp.file_writer.add_summary(exp, global_step) + w_hp.file_writer.add_summary(ssi, global_step) + w_hp.file_writer.add_summary(sei, global_step) + for k, v in metric_dict.items(): + w_hp.add_scalar(k, v, global_step) + + def add_scalar( + self, + tag, + scalar_value, + global_step=None, + walltime=None, + new_style=False, + double_precision=False, + ) -> None: + """Add scalar data to summary. + + Args: + tag (str): Data identifier + scalar_value (float or string/blobname): Value to save + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + with seconds after epoch of event + new_style (boolean): Whether to use new style (tensor field) or old + style (simple_value field). New style could lead to faster data loading. + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + x = range(100) + for i in x: + writer.add_scalar('y=2x', i * 2, i) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_scalar.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalar") + + summary = scalar( + tag, scalar_value, new_style=new_style, double_precision=double_precision + ) + self._get_file_writer().add_summary(summary, global_step, walltime) + + def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None) -> None: + """Add many scalar data to summary. + + Args: + main_tag (str): The parent name for the tags + tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + r = 5 + for i in range(100): + writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r), + 'xcosx':i*np.cos(i/r), + 'tanx': np.tan(i/r)}, i) + writer.close() + # This call adds three values to the same scalar plot with the tag + # 'run_14h' in TensorBoard's scalar section. + + Expected result: + + .. image:: _static/img/tensorboard/add_scalars.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalars") + walltime = time.time() if walltime is None else walltime + fw_logdir = self._get_file_writer().get_logdir() + for tag, scalar_value in tag_scalar_dict.items(): + fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag + if self.all_writers is None: + raise AssertionError("self.all_writers is None") + if fw_tag in self.all_writers: + fw = self.all_writers[fw_tag] + else: + fw = FileWriter( + fw_tag, self.max_queue, self.flush_secs, self.filename_suffix + ) + self.all_writers[fw_tag] = fw + fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime) + + def add_tensor( + self, + tag, + tensor, + global_step=None, + walltime=None, + ) -> None: + """Add tensor data to summary. + + Args: + tag (str): Data identifier + tensor (torch.Tensor): tensor to save + global_step (int): Global step value to record + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + x = torch.tensor([1,2,3]) + writer.add_scalar('x', x) + writer.close() + + Expected result: + Summary::tensor::float_val [1,2,3] + ::tensor::shape [3] + ::tag 'x' + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_tensor") + + summary = tensor_proto(tag, tensor) + self._get_file_writer().add_summary(summary, global_step, walltime) + + def add_histogram( + self, + tag, + values, + global_step=None, + bins="tensorflow", + walltime=None, + max_bins=None, + ) -> None: + """Add histogram to summary. + + Args: + tag (str): Data identifier + values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram + global_step (int): Global step value to record + bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find + other options in: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + writer = SummaryWriter() + for i in range(10): + x = np.random.random(1000) + writer.add_histogram('distribution centers', x + i, i) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_histogram.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram") + if isinstance(bins, str) and bins == "tensorflow": + bins = self.default_bins + self._get_file_writer().add_summary( + histogram(tag, values, bins, max_bins=max_bins), global_step, walltime + ) + + def add_histogram_raw( + self, + tag, + min, + max, + num, + sum, + sum_squares, + bucket_limits, + bucket_counts, + global_step=None, + walltime=None, + ) -> None: + """Add histogram with raw data. + + Args: + tag (str): Data identifier + min (float or int): Min value + max (float or int): Max value + num (int): Number of values + sum (float or int): Sum of all values + sum_squares (float or int): Sum of squares for all values + bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket. + The number of elements of it should be the same as `bucket_counts`. + bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + writer = SummaryWriter() + dummy_data = [] + for idx, value in enumerate(range(50)): + dummy_data += [idx + 0.001] * value + + bins = list(range(50+2)) + bins = np.array(bins) + values = np.array(dummy_data).astype(float).reshape(-1) + counts, limits = np.histogram(values, bins=bins) + sum_sq = values.dot(values) + writer.add_histogram_raw( + tag='histogram_with_raw_data', + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limits=limits[1:].tolist(), + bucket_counts=counts.tolist(), + global_step=0) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_histogram_raw.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw") + if len(bucket_limits) != len(bucket_counts): + raise ValueError( + "len(bucket_limits) != len(bucket_counts), see the document." + ) + self._get_file_writer().add_summary( + histogram_raw( + tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts + ), + global_step, + walltime, + ) + + def add_image( + self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW" + ) -> None: + """Add image data to summary. + + Note that this requires the ``pillow`` package. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + dataformats (str): Image data format specification of the form + CHW, HWC, HW, WH, etc. + Shape: + img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to + convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job. + Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as + corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + img = np.zeros((3, 100, 100)) + img[0] = np.arange(0, 10000).reshape(100, 100) / 10000 + img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 + + img_HWC = np.zeros((100, 100, 3)) + img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 + img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 + + writer = SummaryWriter() + writer.add_image('my_image', img, 0) + + # If you have non-default dimension setting, set the dataformats argument. + writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC') + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_image.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_image") + self._get_file_writer().add_summary( + image(tag, img_tensor, dataformats=dataformats), global_step, walltime + ) + + def add_images( + self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW" + ) -> None: + """Add batched image data to summary. + + Note that this requires the ``pillow`` package. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + dataformats (str): Image data format specification of the form + NCHW, NHWC, CHW, HWC, HW, WH, etc. + Shape: + img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be + accepted. e.g. NCHW or NHWC. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + + img_batch = np.zeros((16, 3, 100, 100)) + for i in range(16): + img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i + img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i + + writer = SummaryWriter() + writer.add_images('my_image_batch', img_batch, 0) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_images.png + :scale: 30 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_images") + self._get_file_writer().add_summary( + image(tag, img_tensor, dataformats=dataformats), global_step, walltime + ) + + def add_image_with_boxes( + self, + tag, + img_tensor, + box_tensor, + global_step=None, + walltime=None, + rescale=1, + dataformats="CHW", + labels=None, + ) -> None: + """Add image and draw bounding boxes on the image. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects) + box should be represented as [x1, y1, x2, y2]. + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + rescale (float): Optional scale override + dataformats (str): Image data format specification of the form + NCHW, NHWC, CHW, HWC, HW, WH, etc. + labels (list of string): The label to be shown for each bounding box. + Shape: + img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument. + e.g. CHW or HWC + + box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): NX4, where N is the number of + boxes and each 4 elements in a row represents (xmin, ymin, xmax, ymax). + """ + torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes") + if labels is not None: + if isinstance(labels, str): + labels = [labels] + if len(labels) != box_tensor.shape[0]: + labels = None + self._get_file_writer().add_summary( + image_boxes( + tag, + img_tensor, + box_tensor, + rescale=rescale, + dataformats=dataformats, + labels=labels, + ), + global_step, + walltime, + ) + + def add_figure( + self, + tag: str, + figure: Union["Figure", list["Figure"]], + global_step: int | None = None, + close: bool = True, + walltime: float | None = None, + ) -> None: + """Render matplotlib figure into an image and add it to summary. + + Note that this requires the ``matplotlib`` package. + + Args: + tag: Data identifier + figure: Figure or a list of figures + global_step: Global step value to record + close: Flag to automatically close the figure + walltime: Optional override default walltime (time.time()) + seconds after epoch of event + """ + torch._C._log_api_usage_once("tensorboard.logging.add_figure") + if isinstance(figure, list): + self.add_image( + tag, + figure_to_image(figure, close), + global_step, + walltime, + dataformats="NCHW", + ) + else: + self.add_image( + tag, + figure_to_image(figure, close), + global_step, + walltime, + dataformats="CHW", + ) + + def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None) -> None: + """Add video data to summary. + + Note that this requires the ``moviepy`` package. + + Args: + tag (str): Data identifier + vid_tensor (torch.Tensor): Video data + global_step (int): Global step value to record + fps (float or int): Frames per second + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Shape: + vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`. + """ + torch._C._log_api_usage_once("tensorboard.logging.add_video") + self._get_file_writer().add_summary( + video(tag, vid_tensor, fps), global_step, walltime + ) + + def add_audio( + self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None + ) -> None: + """Add audio data to summary. + + Args: + tag (str): Data identifier + snd_tensor (torch.Tensor): Sound data + global_step (int): Global step value to record + sample_rate (int): sample rate in Hz + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Shape: + snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1]. + """ + torch._C._log_api_usage_once("tensorboard.logging.add_audio") + self._get_file_writer().add_summary( + audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime + ) + + def add_text(self, tag, text_string, global_step=None, walltime=None) -> None: + """Add text data to summary. + + Args: + tag (str): Data identifier + text_string (str): String to save + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Examples:: + + writer.add_text('lstm', 'This is an lstm', 0) + writer.add_text('rnn', 'This is an rnn', 10) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_text") + self._get_file_writer().add_summary( + text(tag, text_string), global_step, walltime + ) + + def add_onnx_graph(self, prototxt) -> None: + torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph") + self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt)) + + def add_graph( + self, model, input_to_model=None, verbose=False, use_strict_trace=True + ) -> None: + """Add graph data to summary. + + Args: + model (torch.nn.Module): Model to draw. + input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of + variables to be fed. + verbose (bool): Whether to print graph structure in console. + use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_graph") + # A valid PyTorch model should have a 'forward' method + self._get_file_writer().add_graph( + graph(model, input_to_model, verbose, use_strict_trace) + ) + + @staticmethod + def _encode(rawstr): + # I'd use urllib but, I'm unsure about the differences from python3 to python2, etc. + retval = rawstr + retval = retval.replace("%", f"%{ord('%'):02x}") + retval = retval.replace("/", f"%{ord('/'):02x}") + retval = retval.replace("\\", "%%%02x" % (ord("\\"))) # noqa: UP031 + return retval + + def add_embedding( + self, + mat, + metadata=None, + label_img=None, + global_step=None, + tag="default", + metadata_header=None, + ) -> None: + """Add embedding projector data to summary. + + Args: + mat (torch.Tensor or numpy.ndarray): A matrix which each row is the feature vector of the data point + metadata (list): A list of labels, each element will be converted to string + label_img (torch.Tensor): Images correspond to each data point + global_step (int): Global step value to record + tag (str): Name for the embedding + metadata_header (list): A list of headers for multi-column metadata. If given, each metadata must be + a list with values corresponding to headers. + Shape: + mat: :math:`(N, D)`, where N is number of data and D is feature dimension + + label_img: :math:`(N, C, H, W)` + + Examples:: + + import keyword + import torch + meta = [] + while len(meta)<100: + meta = meta+keyword.kwlist # get some strings + meta = meta[:100] + + for i, v in enumerate(meta): + meta[i] = v+str(i) + + label_img = torch.rand(100, 3, 10, 32) + for i in range(100): + label_img[i]*=i/100.0 + + writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img) + writer.add_embedding(torch.randn(100, 5), label_img=label_img) + writer.add_embedding(torch.randn(100, 5), metadata=meta) + + .. note:: + Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for + coloring in the embedding projector. + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_embedding") + mat = make_np(mat) + if global_step is None: + global_step = 0 + # clear pbtxt? + + # Maybe we should encode the tag so slashes don't trip us up? + # I don't think this will mess us up, but better safe than sorry. + subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}" + save_path = os.path.join(self._get_file_writer().get_logdir(), subdir) + + fs = tf.io.gfile + if fs.exists(save_path): + if fs.isdir(save_path): + print( + "warning: Embedding dir exists, did you set global_step for add_embedding()?" + ) + else: + raise NotADirectoryError( + f"Path: `{save_path}` exists, but is a file. Cannot proceed." + ) + else: + fs.makedirs(save_path) + + if metadata is not None: + if mat.shape[0] != len( + metadata + ): + raise AssertionError("#labels should equal with #data points") + make_tsv(metadata, save_path, metadata_header=metadata_header) + + if label_img is not None: + if mat.shape[0] != label_img.shape[0]: + raise AssertionError("#images should equal with #data points") + make_sprite(label_img, save_path) + + if mat.ndim != 2: + raise AssertionError("mat should be 2D, where mat.size(0) is the number of data points") + make_mat(mat, save_path) + + # Filesystem doesn't necessarily have append semantics, so we store an + # internal buffer to append to and re-write whole file after each + # embedding is added + if not hasattr(self, "_projector_config"): + self._projector_config = ProjectorConfig() + embedding_info = get_embedding_info( + metadata, label_img, subdir, global_step, tag + ) + self._projector_config.embeddings.extend([embedding_info]) + + + from google.protobuf import text_format + + config_pbtxt = text_format.MessageToString(self._projector_config) + write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt) + + def add_pr_curve( + self, + tag, + labels, + predictions, + global_step=None, + num_thresholds=127, + weights=None, + walltime=None, + ) -> None: + """Add precision recall curve. + + Plotting a precision-recall curve lets you understand your model's + performance under different threshold settings. With this function, + you provide the ground truth labeling (T/F) and prediction confidence + (usually the output of your model) for each target. The TensorBoard UI + will let you choose the threshold interactively. + + Args: + tag (str): Data identifier + labels (torch.Tensor, numpy.ndarray, or string/blobname): + Ground truth data. Binary label for each element. + predictions (torch.Tensor, numpy.ndarray, or string/blobname): + The probability that an element be classified as true. + Value should be in [0, 1] + global_step (int): Global step value to record + num_thresholds (int): Number of thresholds used to draw the curve. + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + labels = np.random.randint(2, size=100) # binary label + predictions = np.random.rand(100) + writer = SummaryWriter() + writer.add_pr_curve('pr_curve', labels, predictions, 0) + writer.close() + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve") + labels, predictions = make_np(labels), make_np(predictions) + self._get_file_writer().add_summary( + pr_curve(tag, labels, predictions, num_thresholds, weights), + global_step, + walltime, + ) + + def add_pr_curve_raw( + self, + tag, + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + global_step=None, + num_thresholds=127, + weights=None, + walltime=None, + ) -> None: + """Add precision recall curve with raw data. + + Args: + tag (str): Data identifier + true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts + false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts + true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts + false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts + precision (torch.Tensor, numpy.ndarray, or string/blobname): precision + recall (torch.Tensor, numpy.ndarray, or string/blobname): recall + global_step (int): Global step value to record + num_thresholds (int): Number of thresholds used to draw the curve. + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md + """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw") + self._get_file_writer().add_summary( + pr_curve_raw( + tag, + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + num_thresholds, + weights, + ), + global_step, + walltime, + ) + + def add_custom_scalars_multilinechart( + self, tags, category="default", title="untitled" + ) -> None: + """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*. + + Args: + tags (list): list of tags that have been used in ``add_scalar()`` + + Examples:: + + writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330']) + """ + torch._C._log_api_usage_once( + "tensorboard.logging.add_custom_scalars_multilinechart" + ) + layout = {category: {title: ["Multiline", tags]}} + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_custom_scalars_marginchart( + self, tags, category="default", title="untitled" + ) -> None: + """Shorthand for creating marginchart. + + Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*, + which should have exactly 3 elements. + + Args: + tags (list): list of tags that have been used in ``add_scalar()`` + + Examples:: + + writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006']) + """ + torch._C._log_api_usage_once( + "tensorboard.logging.add_custom_scalars_marginchart" + ) + if len(tags) != 3: + raise AssertionError(f"Expected 3 tags, got {len(tags)}.") + layout = {category: {title: ["Margin", tags]}} + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_custom_scalars(self, layout) -> None: + """Create special chart by collecting charts tags in 'scalars'. + + NOTE: This function can only be called once for each SummaryWriter() object. + + Because it only provides metadata to tensorboard, the function can be called before or after the training loop. + + Args: + layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary + {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type + (one of **Multiline** or **Margin**) and the second element should be a list containing the tags + you have used in add_scalar function, which will be collected into the new chart. + + Examples:: + + layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]}, + 'USA':{ 'dow':['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], + 'nasdaq':['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}} + + writer.add_custom_scalars(layout) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars") + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_mesh( + self, + tag, + vertices, + colors=None, + faces=None, + config_dict=None, + global_step=None, + walltime=None, + ) -> None: + """Add meshes or 3D point clouds to TensorBoard. + + The visualization is based on Three.js, + so it allows users to interact with the rendered object. Besides the basic definitions + such as vertices, faces, users can further provide camera parameter, lighting condition, etc. + Please check https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene for + advanced usage. + + Args: + tag (str): Data identifier + vertices (torch.Tensor): List of the 3D coordinates of vertices. + colors (torch.Tensor): Colors for each vertex + faces (torch.Tensor): Indices of vertices within each triangle. (Optional) + config_dict: Dictionary with ThreeJS classes names and configuration. + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Shape: + vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels) + + colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`. + + faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + vertices_tensor = torch.as_tensor([ + [1, 1, 1], + [-1, -1, 1], + [1, -1, -1], + [-1, 1, -1], + ], dtype=torch.float).unsqueeze(0) + colors_tensor = torch.as_tensor([ + [255, 0, 0], + [0, 255, 0], + [0, 0, 255], + [255, 0, 255], + ], dtype=torch.int).unsqueeze(0) + faces_tensor = torch.as_tensor([ + [0, 2, 3], + [0, 3, 1], + [0, 1, 2], + [1, 3, 2], + ], dtype=torch.int).unsqueeze(0) + + writer = SummaryWriter() + writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor) + + writer.close() + """ + torch._C._log_api_usage_once("tensorboard.logging.add_mesh") + self._get_file_writer().add_summary( + mesh(tag, vertices, colors, faces, config_dict), global_step, walltime + ) + + def flush(self) -> None: + """Flushes the event file to disk. + + Call this method to make sure that all pending events have been written to + disk. + """ + if self.all_writers is None: + return + for writer in self.all_writers.values(): + writer.flush() + + def close(self) -> None: + if self.all_writers is None: + return # ignore double close + for writer in self.all_writers.values(): + writer.flush() + writer.close() + # pyrefly: ignore [bad-assignment] + self.file_writer = self.all_writers = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ecf19d33fc8f41b116238525edf5745c2d17ce8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c61dc90ace58c9fc9c4d9a87857042d40800bd60 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/_cycles.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/_cycles.py new file mode 100644 index 0000000000000000000000000000000000000000..df4bf34db211486edf89a9d4580c1bd792ee7097 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/utils/viz/_cycles.py @@ -0,0 +1,506 @@ +# mypy: allow-untyped-defs +import gc +import sys +from typing import Any, NamedTuple +import types +import weakref +import json +from tempfile import NamedTemporaryFile +import torch +from torch.cuda._memory_viz import _frames_fmt, _block_extra +import atexit +import logging +logger = logging.getLogger(__name__) + +def observe_garbage(observer): + enabled = True + + def disable() -> None: + # when GC runs during exit, things like `sys` will already be unloaded + # so we have to disable the callback to avoid hitting errors. + nonlocal enabled + enabled = False + atexit.register(disable) + + def gc_callback(phase, info) -> None: + nonlocal enabled + if not enabled: + return + if phase == "start": + gc.set_debug(gc.DEBUG_SAVEALL) + elif phase == "stop": + orig_trace = sys.getprofile() + self_return = [False] + + def do_collect(*args, **kwargs): + nonlocal enabled + if not self_return[0]: + self_return[0] = True + else: + sys.setprofile(orig_trace) + enabled = False + try: + # things in gc.garbage have survived a collection + # so to free them we have to collect a generation greater than them + # but that might _also_ free other stuff and we don't want to miss + # that stuff. So we have to now force gc at the highest level here, + # report all of what we found, _then_ we can free it up. + if info['generation'] != 2: + gc.collect() + observer(gc.garbage) + gc.garbage.clear() + # we have to re-run GC to clean up the cycles + # we saved from before. + gc.set_debug(0) + before = torch.cuda.memory_allocated() + gc.collect() + after = torch.cuda.memory_allocated() + if before != after: + logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after) + finally: + enabled = True + if orig_trace is not None: + return orig_trace(*args, **kwargs) + sys.setprofile(do_collect) + + gc.callbacks.append(gc_callback) + + # provide a way to disarm the callback + def remove() -> None: + gc.callbacks.remove(gc_callback) + return remove + +# Function to visualize cycles adapted from refcycle: +# Copyright 2013 Mark Dickinson +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def _get_cell_type(): + def f(x=None): + return lambda: x + return type(f().__closure__[0]) + +CellType = _get_cell_type() + +def annotated_references(obj): + """ + Return known information about references held by the given object. + + Returns a mapping from referents to lists of descriptions. Note that there + may be more than one edge leading to any particular referent; hence the + need for a list. Descriptions are currently strings. + + """ + references: dict[int, list[str]] = {} + + def add_reference(name, obj) -> None: + references.setdefault(id(obj), []).append(name) + + def add_attrs(*attrs) -> None: + for attr in attrs: + if hasattr(obj, attr): + add_reference(attr, getattr(obj, attr)) + + def add_cell_references() -> None: + try: + add_attrs("cell_contents") + except ValueError: + # if cell_contents is empty, + # accessing it raises ValueError + # in this case there is no object to + # annotate + pass + + def add_function_references() -> None: + add_attrs("__defaults__", + "__closure__", + "__globals__", + "__code__", + "__name__", + "__module__", + "__doc__" + "__qualname__", + "__annotations__", + "__kwdefaults__") + + + def add_sequence_references() -> None: + for position, item in enumerate(obj): + add_reference(f"[{position}]", item) + + def add_dict_references() -> None: + for key, value in obj.items(): + add_reference("key", key) + add_reference(f"[{repr(key)}]", value) + + def add_set_references() -> None: + for elt in obj: + add_reference("element", elt) + + def add_bound_method_references() -> None: + add_attrs("__self__", "__func__", "im_class") + + def add_weakref_references() -> None: + # For subclasses of weakref, we can't reliably distinguish the + # callback (if any) from other attributes. + if type(obj) is weakref.ref: + referents = gc.get_referents(obj) + if len(referents) == 1: + target = referents[0] + add_reference("__callback__", target) + + + def add_frame_references() -> None: + f_locals = obj.f_locals + add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals") + # Some badly-behaved code replaces the f_locals dict with + # something that doesn't support the full dict interface. So we + # only continue with the annotation if f_locals is a Python dict. + if type(f_locals) is dict: + for name, local in obj.f_locals.items(): + add_reference(f"local {name}", local) + + def add_getset_descriptor_references() -> None: + add_attrs("__objclass__", "__name__", "__doc__") + + type_based_references = { + tuple: add_sequence_references, + list: add_sequence_references, + dict: add_dict_references, + set: add_set_references, + frozenset: add_set_references, + types.FunctionType: add_function_references, + types.FrameType: add_frame_references, + CellType: add_cell_references, + types.MethodType: add_bound_method_references, + weakref.ref: add_weakref_references, + types.GetSetDescriptorType: add_getset_descriptor_references, + } + + for type_ in type(obj).__mro__: + if type_ in type_based_references: + type_based_references[type_]() + + add_attrs("__dict__", "__class__") + if isinstance(obj, type): + add_attrs("__mro__") + + return references + +############################################################################### +# Object annotations. + + +BASE_TYPES = (int, float, complex, type(None), str, bytes) +FRAME_FILENAME_LIMIT = 32 + +def object_annotation(obj): + """ + Return a string to be used for Graphviz nodes. + + The string should be short but as informative as possible. + """ + + def format_sequence(obj): + body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8]) + if len(obj) > 8: + body = f'{body}, ...{len(obj) - 8}' + return body + + # For basic types, use the repr. + if isinstance(obj, BASE_TYPES): + return repr(obj) + if type(obj).__name__ == 'function': + return f"function\n{obj.__name__}" + elif isinstance(obj, types.MethodType): + try: + func_name = obj.__func__.__qualname__ + except AttributeError: + func_name = "" + return f"instancemethod\n{func_name}" + elif isinstance(obj, list): + return f"[{format_sequence(obj)}]" + elif isinstance(obj, tuple): + return f"({format_sequence(obj)})" + elif isinstance(obj, dict): + return f"dict[{len(obj)}]" + elif isinstance(obj, types.ModuleType): + return f"module\n{obj.__name__}" + elif isinstance(obj, type): + return f"type\n{obj.__name__}" + elif isinstance(obj, weakref.ref): + referent = obj() + if referent is None: + return "weakref (dead referent)" + else: + return f"weakref to id 0x{id(referent):x}" + elif isinstance(obj, types.FrameType): + filename = obj.f_code.co_filename + if len(filename) > FRAME_FILENAME_LIMIT: + filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):] + return f"frame\n{filename}:{obj.f_lineno}" + elif is_cuda_tensor(obj): + return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})" + else: + return f"object\n{type(obj).__module__}.{type(obj).__name__}" + + + +class Node(NamedTuple): + label: str + context: str | None + root: bool + referrents: list[tuple[str, int]] + +def create_graph(objects, *, context=None, filter=None): + if context is None: + context = cuda_allocation_context() + if filter is None: + filter = is_cuda_tensor + + objects = [obj for obj in objects if not isinstance(obj, weakref.ProxyTypes)] + nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects] + node_referrers: list[list[int]] = [[] for obj in objects] + + id_to_node = {id(obj): i for i, obj in enumerate(objects)} + for obj in objects: + fidx = id_to_node[id(obj)] + f = nodes[fidx] + references = annotated_references(obj) + for referrent in gc.get_referents(obj): + rid = id(referrent) + tidx = id_to_node.get(rid) + if tidx is None: + continue + labels = references.get(rid, ["?"]) + node_referrers[tidx].append(fidx) + for label in labels: + f.referrents.append((label, tidx)) + + to_search = [i for i, n in enumerate(nodes) if n.root] + to_keep = set() + while to_search: + idx = to_search.pop() + if idx in to_keep: + continue + to_keep.add(idx) + referrers = node_referrers[idx] + to_search.extend(referrers) + id_to_filtered_id: dict[int, int] = {} + filtered: list[Any] = [] + for i, n in enumerate(nodes): + if i in to_keep: + id_to_filtered_id[i] = len(id_to_filtered_id) + filtered.append(n) + for n in filtered: + n.referrents[:] = [(label, id_to_filtered_id[idx]) + for (label, idx) in n.referrents + if idx in id_to_filtered_id] + return filtered + +def escape(n): + return json.dumps(n) + + +def is_cuda_tensor(obj): + return ( + isinstance(obj, torch.Tensor) and + obj.device.type == "cuda" and + not isinstance(obj, torch._subclasses.FakeTensor) + ) + +def cuda_allocation_context(): + snapshot = torch.cuda.memory._snapshot() + addr_to_frame = {} + for seg in snapshot['segments']: + addr = seg['address'] + for blk in seg['blocks']: + if blk['state'] == 'active_allocated': + frames, _real_size = _block_extra(blk) + addr_to_frame[addr] = frames + addr += blk['size'] + + def object_context(obj): + if is_cuda_tensor(obj): + addr = obj.untyped_storage().data_ptr() + frames = addr_to_frame.get(addr) + if frames is not None: + return '\n'.join(_frames_fmt(frames, full_filename=True)) + return None + return object_context + +def to_dot(nodes): + lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;'] + for i, n in enumerate(nodes): + lines.append(f'{i} [label={escape(n.label)}, color={"red" if n.root else "black"}];') + + for i, f in enumerate(nodes): + for label, j in f.referrents: + lines.append(f'{i} -> {j} [label = {escape(label)}]') + lines.append("}\n") + return '\n'.join(lines) + +_template = """ + + + + + + +
+
+
+
+
Mouse over tensor objects to see where they were allocated.
+
+
+ + + + +""" +_listener_template = """ +document.getElementById('node{id}').addEventListener('mouseover', function(event) {{ + document.getElementById("stacktrace").textContent = {stack} +}}) +""" +def to_html(nodes): + listeners = [] + for i, n in enumerate(nodes): + if n.context is None: + continue + s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) + # pyrefly: ignore [bad-argument-type] + listeners.append(s) + dot = to_dot(nodes) + return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners)) + +def observe_tensor_cycles(callback): + torch.cuda.memory._record_memory_history(max_entries=100000) + + def observer(garbage) -> None: + if garbage: + if not any(is_cuda_tensor(obj) for obj in garbage): + logger.info("No CUDA Tensors found in garbage") + return + callback(to_html(create_graph(garbage))) + return observe_garbage(observer) + + +def warn_tensor_cycles(): + """ + Install a warning that reports whenever a cycle that is holding CUDA memory is observed. + + The warning produces an .html file that visualizes the cycle, + and links it to the stack frame that allocated the CUDA tensor. + + Reference cycles are freed by the cycle collector rather than being cleaned up + when the objects in the cycle first become unreachable. If a cycle points to a tensor, + the CUDA memory for that tensor will not be freed until garbage collection runs. + Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as + non-deterministic allocation behavior which is harder to debug. + """ + logger.info("Watching Python reference cycles for CUDA Tensors.") + + def write_and_log(html) -> None: + with NamedTemporaryFile('w', suffix='.html') as f: + f.write(html) + logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) + return observe_tensor_cycles(write_and_log)